详细讲解pyspark中dsl格式进行大数据开发中的的所有编程情况
在 PySpark 中,DSL(Domain Specific Language)编程主要基于DataFrame API(Python 中 DataFrame 与 Dataset 概念统一),是处理结构化 / 半结构化大数据的核心方式。它通过链式调用 API 方法实现数据加载、清洗、转换、聚合、分析等全流程操作,无需编写 SQL 语句,更贴合 Python 编程习惯。
一、核心基础:环境初始化与 DataFrame 概念
1. 环境准备(初始化 SparkSession)
所有 PySpark DSL 操作的入口是SparkSession
,需先创建会话对象:
from pyspark.sql import SparkSession# 初始化SparkSession
spark = SparkSession.builder \.appName("PySpark-DSL-Example") # 应用名称.master("local[*]") # 本地模式(生产环境无需指定,由集群管理).config("spark.sql.shuffle.partitions", "4") # 调整shuffle分区数(默认200,小数据场景可减小).getOrCreate()# 导入内置函数(常用别名F,简化调用)
from pyspark.sql import functions as F
# 导入数据类型(用于定义schema)
from pyspark.sql.types import *
2. DataFrame 核心概念
- DataFrame:分布式数据集合,以命名列(Column)组织,类似关系型数据库的 “表”,但底层基于 RDD 实现,包含schema(元数据,描述列名和类型)。
- 不可变性:DataFrame 是不可变的,所有转换操作(如
withColumn
、filter
)都会生成新的 DataFrame,原对象不变。 - 惰性执行:转换操作(Transformation)不会立即执行,只有触发行动操作(Action,如
show
、count
、write
)时才会真正计算,优化执行效率。
二、数据加载:创建 DataFrame
PySpark 支持从多种数据源加载数据生成 DataFrame,以下是常见场景:
1. 从文件加载(CSV/JSON/Parquet 等)
(1)CSV 文件(最常用)
# 读取CSV(自动推断schema,适用于测试)
df = spark.read \.option("header", "true") # 第一行为列名.option("inferSchema", "true") # 自动推断列类型(生产环境不推荐,效率低).option("sep", ",") # 分隔符(默认逗号).option("nullValue", "NA") # 指定空值标识.csv("path/to/data.csv") # 文件路径(支持本地/分布式文件系统如HDFS)# 生产环境:手动指定schema(高效且避免类型推断错误)
schema = StructType([StructField("id", IntegerType(), nullable=False), # 非空整数StructField("name", StringType(), nullable=True),StructField("birth_date", StringType(), nullable=True), # 先按字符串读,后续转为日期StructField("salary", DoubleType(), nullable=True)
])df = spark.read \.option("header", "true") \.schema(schema) # 应用手动定义的schema.csv("path/to/data.csv")
(2)JSON 文件
# 读取JSON(支持单行JSON或多行JSON)
df = spark.read \.option("multiline", "true") # 多行JSON(默认false,单行JSON).json("path/to/data.json")
(3)Parquet 文件(Spark 默认存储格式,列式存储,压缩率高)
df = spark.read.parquet("path/to/data.parquet")
2. 从数据库加载(MySQL/PostgreSQL 等)
需添加对应数据库的 JDBC 驱动(如 MySQL 的mysql-connector-java
):
df = spark.read \.format("jdbc") \.option("url", "jdbc:mysql://host:port/db_name") \.option("dbtable", "table_name") # 表名或SQL查询(如"(select * from t where id>100) tmp").option("user", "username") \.option("password", "password") \.option("driver", "com.mysql.cj.jdbc.Driver") # 驱动类名.load()
3. 从 RDD 或集合创建
# 从Python列表创建
data = [("Alice", 25, "F"), ("Bob", 30, "M")]
df = spark.createDataFrame(data, schema=["name", "age", "gender"]) # 手动指定列名# 从RDD创建
rdd = spark.sparkContext.parallelize(data)
df = rdd.toDF(schema=["name", "age", "gender"])
三、基础操作:查看与验证数据
加载数据后,需先验证数据格式和内容,常用 API:
操作 | 功能说明 | 示例代码 |
---|---|---|
printSchema() | 打印 DataFrame 的 schema(列名 + 类型) | df.printSchema() |
show(n, truncate) | 显示前 n 行数据(truncate=False 不截断长字符串) | df.show(5, truncate=False) |
columns | 返回所有列名列表 | print(df.columns) |
dtypes | 返回(列名,类型)列表 | print(df.dtypes) |
count() | 计算总行数(Action 操作) | print(f"总条数: {df.count()}") |
describe(*cols) | 计算列的统计信息(计数、均值、标准差等) | df.describe("age", "salary").show() |
head(n)/take(n) | 获取前 n 行数据(返回 Row 对象列表) | print(df.head(2)) |
四、核心操作:数据转换与清洗
PySpark DSL 提供了丰富的 API 用于数据转换,覆盖列操作、行过滤、类型转换等场景。
1. 列操作
(1)选择列(select
)
# 选择单列
df.select("name").show()# 选择多列
df.select("name", "age").show()# 结合列计算(如年龄+1)
df.select(F.col("name"), # F.col()用于引用列(推荐,支持链式操作)F.col("age"),(F.col("age") + 1).alias("age_plus_1") # alias()重命名列
).show()# 通配符选择(如所有以"col_"开头的列)
df.select(F.col("`col_*`")).show() # 反引号处理特殊列名
(2)重命名列(withColumnRenamed
)
# 重命名单个列
df_renamed = df.withColumnRenamed("old_col", "new_col")# 重命名多个列(链式调用)
df_renamed = df \.withColumnRenamed("a", "col_a") \.withColumnRenamed("b", "col_b")
(3)删除列(drop
)
# 删除单个列
df_dropped = df.drop("age")# 删除多个列
df_dropped = df.drop("age", "gender")# 删除不存在的列(不报错)
df_dropped = df.drop("nonexistent_col")
2. 行操作
(1)过滤行(filter
/where
,二者等价)
# 条件:年龄>25
df.filter(F.col("age") > 25).show()# 多条件(且:&,或:|,非:~;注意括号)
df.filter((F.col("age") > 25) & (F.col("gender") == "F")
).show()# 字符串条件(不推荐,无类型检查)
df.where("age > 25 and gender = 'F'").show()# 空值处理(过滤空值)
df.filter(F.col("birth_date").isNotNull()).show()# 过滤非空且非空字符串
df.filter(F.col("name").isNotNull() & (F.col("name") != "")
).show()
(2)去重(distinct
/dropDuplicates
)
# 所有列完全重复的行去重
df_distinct = df.distinct()# 指定列重复的行去重(保留第一条)
df_drop_dup = df.dropDuplicates(subset=["name", "gender"]) # 按name和gender去重
(3)排序(orderBy
/sort
,二者等价)
# 按年龄升序(默认)
df.orderBy("age").show()# 按年龄降序(F.desc())
df.orderBy(F.desc("age")).show()# 多列排序(年龄降序,姓名升序)
df.sort(F.col("age").desc(),F.col("name").asc()
).show()
3. 类型转换(cast
)
解决数据类型不匹配问题(如字符串转日期、字符串转数字):
# 字符串转整数(若转换失败,结果为null)
df = df.withColumn("age", F.col("age").cast(IntegerType()))# 字符串转日期(指定格式,如"yyyy-MM-dd")
df = df.withColumn("birth_date",F.to_date(F.col("birth_date"), "yyyy-MM-dd") # 比cast更灵活,支持指定格式
)# 日期转字符串
df = df.withColumn("birth_str",F.date_format(F.col("birth_date"), "yyyyMMdd") # 转为"20000101"格式
)# 字符串转浮点数
df = df.withColumn("salary", F.col("salary").cast(DoubleType()))
4. 新增列(withColumn
)
通过现有列计算或常量值新增列:
# 基于现有列计算(年龄是否成年)
df = df.withColumn("is_adult", F.col("age") >= 18)# 常量列(所有行值相同)
df = df.withColumn("source", F.lit("csv_file")) # F.lit()表示常量# 条件列(when/otherwise,类似if-else)
df = df.withColumn("age_group",F.when(F.col("age") < 18, "少年").when((F.col("age") >= 18) & (F.col("age") < 30), "青年").when((F.col("age") >= 30) & (F.col("age") < 50), "中年").otherwise("老年") # 其他情况
)
5. 字符串处理
PySpark 提供丰富的字符串函数(pyspark.sql.functions
):
函数 | 功能 | 示例 |
---|---|---|
regexp_replace | 正则替换 | F.regexp_replace("name", "A", "a") |
substring | 截取子串(索引从 1 开始) | F.substring("birth_str", 1, 4) # 取年份 |
upper /lower | 大小写转换 | F.upper("name") |
trim /ltrim /rtrim | 去除空格 | F.trim("name") |
split | 分割字符串为数组 | F.split("address", ",") |
concat /concat_ws | 拼接字符串 | F.concat_ws("-", "year", "month") |
示例:
# 清洗日期字符串(移除所有非数字字符,如"2000/01/01"→"20000101")
df = df.withColumn("cleaned_birth",F.regexp_replace(F.col("birth_date_str"), "[^0-9]", "")
)# 拆分姓名为姓和名(假设格式为" lastName, firstName")
df = df.withColumn("split_name",F.split(F.trim("name"), ", ") # 先去空格,再按", "分割
).withColumn("first_name",F.col("split_name")[1] # 取数组第二个元素
).withColumn("last_name",F.col("split_name")[0] # 取数组第一个元素
).drop("split_name") # 删除临时列
6. 日期时间处理
针对DateType
或TimestampType
列的操作:
函数 | 功能 | 示例 |
---|---|---|
to_date | 字符串转日期 | F.to_date("str", "yyyy-MM-dd") |
date_add /date_sub | 日期加减天数 | F.date_add("birth_date", 1) |
months_between | 两个日期相差月数 | F.months_between("end", "start") |
year /month /day | 提取年 / 月 / 日 | F.year("birth_date") |
current_date /current_timestamp | 当前日期 / 时间 | F.current_date() |
示例:
# 计算年龄(当前年份 - 出生年份)
df = df.withColumn("calc_age",F.year(F.current_date()) - F.year(F.col("birth_date"))
)# 计算距今天数
df = df.withColumn("days_since_birth",F.datediff(F.current_date(), F.col("birth_date"))
)
五、高级操作:聚合、连接与窗口函数
1. 聚合操作(groupBy
)
按列分组后进行统计(如计数、求和、均值等):
# 按性别分组,统计人数和平均年龄
gender_stats = df.groupBy("gender") \.agg(F.count("id").alias("total_people"), # 计数(非空id的数量)F.avg("age").alias("avg_age"), # 平均年龄F.max("salary").alias("max_salary"), # 最高薪资F.min("salary").alias("min_salary") # 最低薪资)
gender_stats.show()# 全局聚合(不分组,统计整体)
total_stats = df.agg(F.count("*").alias("total_rows"), # 总条数(包括null)F.sum("salary").alias("total_salary")
)
2. 连接操作(join
,多表关联)
合并多个 DataFrame(类似 SQL 的 JOIN),支持内连接、左连接等:
# 示例:员工表(df_emp)与部门表(df_dept)关联
df_emp = spark.createDataFrame([("1", "Alice", "10"), ("2", "Bob", "20")],["emp_id", "name", "dept_id"]
)
df_dept = spark.createDataFrame([("10", "HR"), ("20", "Tech"), ("30", "Finance")],["dept_id", "dept_name"]
)# 内连接(只保留两表都匹配的行)
inner_join = df_emp.join(df_dept,on="dept_id", # 连接键(若列名不同,用on=[df_emp.a == df_dept.b])how="inner"
)# 左连接(保留左表所有行,右表无匹配则为null)
left_join = df_emp.join(df_dept,on="dept_id",how="left" # 或"left_outer"
)# 右连接(保留右表所有行)
right_join = df_emp.join(df_dept, on="dept_id", how="right")# 全连接(保留两表所有行)
full_join = df_emp.join(df_dept, on="dept_id", how="full")
注意:连接后若有重名列(非连接键),需用alias
区分:
df_emp.alias("e").join(df_dept.alias("d"),F.col("e.dept_id") == F.col("d.dept_id"),how="inner"
).select("e.emp_id", "e.name", "d.dept_name").show() # 明确指定列来源
3. 窗口函数(Window
,分组内的精细计算)
用于实现 “分组内排序”“Top N”“累计求和” 等场景,需先定义窗口规则:
from pyspark.sql.window import Window# 示例:按部门分组,计算每个员工的薪资排名
# 1. 定义窗口:按部门分区(partitionBy),按薪资降序排序(orderBy)
window_spec = Window \.partitionBy("dept_id") \.orderBy(F.col("salary").desc())# 2. 应用窗口函数
df = df.withColumn("rank_in_dept", # 排名(相同值会占用相同名次,后续名次跳过)F.rank().over(window_spec)
).withColumn("dense_rank_in_dept", # 密集排名(相同值占用相同名次,后续名次不跳过)F.dense_rank().over(window_spec)
).withColumn("row_num_in_dept", # 行号(即使值相同,名次也唯一)F.row_number().over(window_spec)
)# 3. 取每个部门薪资前2的员工
top2_in_dept = df.filter(F.col("row_num_in_dept") <= 2)
常用窗口函数:
- 排名类:
rank()
、dense_rank()
、row_number()
- 聚合类:
sum()over()
、avg()over()
(如 “累计销售额”) - 偏移类:
lag()
(取前 n 行值)、lead()
(取后 n 行值)
六、用户自定义函数(UDF)
当内置函数无法满足需求时,可自定义函数扩展功能:
1. 普通 UDF(基于 Python 函数)
# 定义Python函数:计算姓名长度
def name_length(name):return len(name) if name is not None else 0# 注册为UDF(指定返回类型)
name_length_udf = F.udf(name_length, IntegerType())# 使用UDF
df = df.withColumn("name_len", name_length_udf(F.col("name")))
2. Pandas UDF(向量化 UDF,性能优于普通 UDF)
适用于大数据量场景,基于 Pandas Series 处理:
import pandas as pd
from pyspark.sql.functions import pandas_udf# 定义Pandas UDF(输入输出为Pandas Series)
@pandas_udf(IntegerType())
def pandas_name_length(name_series: pd.Series) -> pd.Series:return name_series.str.len().fillna(0) # 利用Pandas字符串方法# 使用Pandas UDF
df = df.withColumn("name_len", pandas_name_length(F.col("name")))
注意:UDF 会打破 Spark 的优化逻辑,尽量优先使用内置函数;必须指定返回类型,否则可能报错。
七、数据写出(持久化结果)
处理完成后,将 DataFrame 写出到文件或数据库:
1. 写出到文件
# 写出为Parquet(推荐,压缩率高,保留schema)
df.write \.mode("overwrite") # 写出模式:overwrite(覆盖)/append(追加)/ignore(忽略)/errorifexists(报错).parquet("path/to/output.parquet")# 写出为CSV(需指定header,否则无列名)
df.write \.mode("append") \.option("header", "true") \.option("sep", ",") \.csv("path/to/output.csv")# 写出为JSON
df.write.json("path/to/output.json")
2. 写出到数据库
df.write \.format("jdbc") \.option("url", "jdbc:mysql://host:port/db_name") \.option("dbtable", "target_table") \.option("user", "username") \.option("password", "password") \.mode("overwrite") \.save()
八、性能优化技巧
- 指定 schema:读取数据时手动定义 schema,避免
inferSchema
(减少 IO 和计算开销)。 - 合理使用缓存:对重复使用的 DataFrame 进行缓存(
cache()
或persist()
),减少重复计算:df_cached = df.cache() # 缓存到内存(默认MEMORY_AND_DISK级别) df_cached.count() # 触发缓存
- 减少数据量:尽早过滤(
filter
)和选择必要列(select
),避免大表全量处理。 - 调整分区:
- 读取后分区数不合理:
df.repartition(8)
(增加分区,适合大表)或df.coalesce(2)
(减少分区,不 shuffle)。 - shuffle 操作(如
groupBy
、join
)前设置spark.sql.shuffle.partitions
(根据集群资源调整,通常为核心数的 2-3 倍)。
- 读取后分区数不合理:
- 广播小表:小表与大表连接时,用
broadcast
广播小表,避免大表 shuffle:from pyspark.sql.functions import broadcast df_large.join(broadcast(df_small), on="id", how="inner") # 广播df_small
- 避免
collect()
:collect()
会将分布式数据拉取到 Driver 端,可能导致 OOM,小数据才用;大数据用take(n)
或写出到文件。
九、综合案例:用户行为数据分析
假设需分析用户行为数据(user_behavior.csv
),包含user_id
、action
(点击 / 购买)、action_time
、product_id
,目标是:
- 清洗数据(转换时间格式,过滤无效值);
- 统计每个用户的点击和购买次数;
- 计算每个用户的首购时间。
# 1. 加载数据并定义schema
schema = StructType([StructField("user_id", StringType(), False),StructField("action", StringType(), False),StructField("action_time", StringType(), False),StructField("product_id", StringType(), True)
])df = spark.read \.option("header", "true") \.schema(schema) \.csv("user_behavior.csv")# 2. 数据清洗
df_clean = df \# 过滤无效动作(只保留点击和购买).filter(F.col("action").isin(["click", "purchase"])) \# 转换时间格式(字符串→时间戳).withColumn("action_ts", F.to_timestamp("action_time", "yyyy-MM-dd HH:mm:ss")) \# 删除无效时间行.filter(F.col("action_ts").isNotNull()) \.drop("action_time") # 丢弃原字符串时间列# 3. 统计每个用户的点击和购买次数
user_action_count = df_clean.groupBy("user_id") \.pivot("action", ["click", "purchase"]) # 透视action列,转为click和purchase列.count() \.fillna(0) # 空值填充为0.withColumnRenamed("click", "click_count") \.withColumnRenamed("purchase", "purchase_count")# 4. 计算每个用户的首购时间
# 定义窗口:按用户分区,按时间升序排序
window_first_purchase = Window \.partitionBy("user_id") \.orderBy("action_ts")first_purchase = df_clean \# 只保留购买行为.filter(F.col("action") == "purchase") \# 标记每个用户的第一条购买记录.withColumn("row_num", F.row_number().over(window_first_purchase)) \.filter(F.col("row_num") == 1) \# 提取首购时间和商品.select("user_id",F.col("action_ts").alias("first_purchase_time"),"product_id")# 5. 合并结果并写出
result = user_action_count.join(first_purchase,on="user_id",how="left" # 左连接,保留所有用户(包括无购买的)
)result.write \.mode("overwrite") \.parquet("user_behavior_analysis_result")# 关闭SparkSession
spark.stop()
总结
PySpark DSL 编程以 DataFrame API 为核心,覆盖了从数据加载、清洗、转换、聚合到写出的全流程,通过链式调用实现高效的大数据处理。关键在于:
- 熟练掌握基础 API(
select
、filter
、withColumn
等); - 理解惰性执行机制,合理使用缓存;
- 灵活运用聚合、连接、窗口函数解决复杂业务问题;
- 关注性能优化,避免常见陷阱(如无意义的全表扫描、滥用 UDF)。
实际开发中,需结合具体业务场景选择合适的 API,并通过 Spark UI(默认 4040 端口)监控作业执行情况,持续优化。