Aggregation

所谓聚合计算就是将数据按照某种方式聚合成组, 在每一组内部进行某种计算, 然后再根据组进行汇总. 例如: 把订单数据按照日期汇总, 计算每天的总销售额. 简单来说就是 SQL 中的 GROUP BY.

在 Spark 中聚合计算的语法非常灵活以及强大.

[1]:
import os
import sys
import site

cwd = os.getcwd()
print(f"Current directory: {cwd}")
print(f"Current Python version: {sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}")
print(f"Current Python interpreter: {sys.executable}")
print(f"Current site-packages: {site.getsitepackages()}")

sys.path.append(os.path.join(cwd, "site-packages"))
Current directory: /home/jovyan/docs/source/02-Aggregation
Current Python version: 3.10.5
Current Python interpreter: /opt/conda/bin/python
Current site-packages: ['/opt/conda/lib/python3.10/site-packages']
[38]:
# 首先创建一个 Spark Session
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark
[38]:

SparkSession - in-memory

SparkContext

Spark UI

Version
v3.3.0
Master
local[*]
AppName
pyspark-shell

Basic Syntax

pyspark.sql.functions 这个模块下有非常多用于计算的函数, 这里我们先将它导入, 以便后续使用.

[39]:
import pyspark.sql.functions as func
[40]:
# Create from Python list of tuple
pdf = spark.createDataFrame(
    [
        (1, "2022-01-01", 5),
        (2, "2022-01-01", 8),
        (3, "2022-01-01", 2),
        (4, "2022-01-02", 9),
        (5, "2022-01-02", 10),
        (6, "2022-01-02", 7),
        (7, "2022-01-03", 1),
        (8, "2022-01-03", 4),
        (9, "2022-01-03", 6),
        (10, "2022-01-03", 3),
    ],
    ("order_id", "date", "amount")
)
pdf.show()
+--------+----------+------+
|order_id|      date|amount|
+--------+----------+------+
|       1|2022-01-01|     5|
|       2|2022-01-01|     8|
|       3|2022-01-01|     2|
|       4|2022-01-02|     9|
|       5|2022-01-02|    10|
|       6|2022-01-02|     7|
|       7|2022-01-03|     1|
|       8|2022-01-03|     4|
|       9|2022-01-03|     6|
|      10|2022-01-03|     3|
+--------+----------+------+

聚合计算实际上有两个步骤, 先是 聚合, 然后才是 计算. 在 Spark 中由于 Lazy Load 的特性, 我们可以先对数据进行聚合, 生成一个 pyspark.sql.group.GroupedData 聚合对象, 此时计算没有实际发生, 所以没有开销, 而且我们可以复用这个聚合对象进行不同的计算.

[70]:
gdf = pdf.groupBy(pdf.date)
gdf
[70]:
<pyspark.sql.group.GroupedData at 0xffff54ae2440>

下面的例子用到了 GroupedData 对象自带的一些聚合函数. 这只适用于聚合计算针对单一列, 比较简单的情况.

[71]:
# ( ... ) provides better readability
# 统计每个组的行数
(
    gdf.count().alias("total_orders")
).show()
+----------+-----+
|      date|count|
+----------+-----+
|2022-01-01|    3|
|2022-01-02|    3|
|2022-01-03|    4|
+----------+-----+

[72]:
# 计算每个组的销售额总额
(
    gdf.sum("amount").alias("total_amount")
).show()
+----------+-----------+
|      date|sum(amount)|
+----------+-----------+
|2022-01-01|         15|
|2022-01-02|         26|
|2022-01-03|         14|
+----------+-----------+

如果要对多个列甚至综合起来进行计算, 那么就要用到 agg 方法. 该方法和 select 类似, 支持更复杂的计算.

[73]:
gdf.agg(func.count(pdf.order_id)).show()
gdf.agg(func.sum(pdf.amount)).show()
+----------+---------------+
|      date|count(order_id)|
+----------+---------------+
|2022-01-01|              3|
|2022-01-02|              3|
|2022-01-03|              4|
+----------+---------------+

+----------+-----------+
|      date|sum(amount)|
+----------+-----------+
|2022-01-01|         15|
|2022-01-02|         26|
|2022-01-03|         14|
+----------+-----------+

[74]:
# 把两个统计数据放在一起
(
    gdf.agg(
        func.count(pdf.order_id).alias("total_orders"),
        func.sum(pdf.amount).alias("total_amounts")
    )
).show()
+----------+------------+-------------+
|      date|total_orders|total_amounts|
+----------+------------+-------------+
|2022-01-01|           3|           15|
|2022-01-02|           3|           26|
|2022-01-03|           4|           14|
+----------+------------+-------------+

[75]:
# 可以对两个列先单独计算, 再联合起来计算
# 先对两个列分别进行统计行数和求和, 最后再求乘积
gdf.agg(
    func.count(pdf.order_id) * func.sum(pdf.amount)
).show()
+----------+-------------------------------+
|      date|(count(order_id) * sum(amount))|
+----------+-------------------------------+
|2022-01-01|                             45|
|2022-01-02|                             78|
|2022-01-03|                             56|
+----------+-------------------------------+

[76]:
# 也可以对两个列先联合起来计算, 再汇总计算
# 先对每个 pair 求乘积, 最后把乘积加起来
gdf.agg(
    func.sum(pdf.order_id * pdf.amount)
).show()
+----------+------------------------+
|      date|sum((order_id * amount))|
+----------+------------------------+
|2022-01-01|                      27|
|2022-01-02|                     128|
|2022-01-03|                     123|
+----------+------------------------+

我们还能使用自定义 Python 函数对聚合后的数据进行计算. 由于数据是已经被聚合的了, 那么这个函数的输入则是一个类似列表的结构.

PySpark 支持两种 UDF:

  • 纯 Python udf, 接受的参数是一个单个值.

  • pandas_udf: 接受的参数是一个 pandas.Series, 可以理解为一个带有 index 的列表.

对于聚合计算我们通常使用 pandas_udf.

在下面的例子里我们先定义了一个 Python 函数, 接受 pandas.Series 参数. 具体的逻辑是把里面的值乘以 10 再相加. 然后我们用 pyspark.sql.functions.pandas_udf 将这个纯 Python 函数注册成一个 pandas_udf, 并且我们要显式告诉 Spark 他的返回对象是一个整数类型, 而函数的类型是用于 Aggregation 聚合计算的. 此时这个 Python 函数就已经变成一个对 column 进行计算的算子 (Operator) 了.

[77]:
import pandas as pd
from pyspark.sql.types import IntegerType

@func.pandas_udf(
    returnType="int",
    functionType=func.PandasUDFType.GROUPED_AGG,
)
def time_ten_and_sum_udf(values: pd.Series):
    total = 0
    for v in values:
        total += v * 10
    return total

gdf.agg(
    time_ten_and_sum_udf(pdf.amount)
).show()
+----------+----------------------------+
|      date|time_ten_and_sum_udf(amount)|
+----------+----------------------------+
|2022-01-01|                         150|
|2022-01-02|                         260|
|2022-01-03|                         140|
+----------+----------------------------+

Group By Multiple Columns

有时候我们需要对多个列进行聚合.

在下面的例子里我们有一个超市的销售数据. 有订单号, 日期, 商品, 以及在订单内卖出去的数量.

现在我们想知道在每一天里, 每个商品一共卖出去了多少个, 以及有多少个订单购买了这件商品.

[78]:
# Create from Python list of tuple
pdf = spark.createDataFrame(
    [
        (1, "2022-01-01", "apple", 6),
        (1, "2022-01-01", "banana", 12),
        (1, "2022-01-01", "apple", 3),
        (1, "2022-01-01", "banana", 7),
        (2, "2022-01-02", "apple", 12),
        (2, "2022-01-02", "apple", 24),
        (2, "2022-01-02", "banana", 8),
    ],
    ("order_id", "date", "item", "quantity")
)
pdf.show()
+--------+----------+------+--------+
|order_id|      date|  item|quantity|
+--------+----------+------+--------+
|       1|2022-01-01| apple|       6|
|       1|2022-01-01|banana|      12|
|       1|2022-01-01| apple|       3|
|       1|2022-01-01|banana|       7|
|       2|2022-01-02| apple|      12|
|       2|2022-01-02| apple|      24|
|       2|2022-01-02|banana|       8|
+--------+----------+------+--------+

[79]:
gdf = pdf.groupBy("date", "item")
[84]:
gdf.agg(
    func.count(pdf.order_id).alias("n_orders_has_this_item"),
    func.sum(pdf.quantity).alias("sale_quantity"),
).show()
+----------+------+----------------------+-------------+
|      date|  item|n_orders_has_this_item|sale_quantity|
+----------+------+----------------------+-------------+
|2022-01-01| apple|                     2|            9|
|2022-01-01|banana|                     2|           19|
|2022-01-02| apple|                     2|           36|
|2022-01-02|banana|                     1|            8|
+----------+------+----------------------+-------------+

Window Function

在 SQL 中的聚合查询里, 窗口函数是非常重要的功能.

在下面的例子里, 我们有一个各个部门的雇员的工资数据. 我们希望知道每个部门里工资最高的雇员是谁. 这里的 Window 就是根据 department 分组, 然后我们在 Window 内根据 salary 进行排序, 然后给每一行加上序号, 最后我们只需要取出序号等于 1 的数据即可.

[85]:
# Create from Python list of tuple
pdf = spark.createDataFrame(
    [
        ("HR", "alice", 70000),
        ("HR", "bob", 56000),
        ("IT", "cathy", 68000),
        ("IT", "david", 83000),
    ],
    ("department", "employee", "salary")
)
pdf.show()
+----------+--------+------+
|department|employee|salary|
+----------+--------+------+
|        HR|   alice| 70000|
|        HR|     bob| 56000|
|        IT|   cathy| 68000|
|        IT|   david| 83000|
+----------+--------+------+

[93]:
# 先看看给每一行标上按工资排序的行号是什么样子
from pyspark.sql.window import Window

pdf.select(
    pdf.department,
    pdf.employee,
    func.row_number().over(
        Window.partitionBy("department").orderBy(func.col("salary").desc())
    ).alias("in_dept_salary_rank")
).show()
+----------+--------+-------------------+
|department|employee|in_dept_salary_rank|
+----------+--------+-------------------+
|        HR|   alice|                  1|
|        HR|     bob|                  2|
|        IT|   david|                  1|
|        IT|   cathy|                  2|
+----------+--------+-------------------+

[94]:
# 最后输出数据. 和 SQL 中我们需要用 Sub Query 不同, 我们可以直接用 .filter 来对中间状态的表筛选数据
(
    pdf.select(
        pdf.department,
        pdf.employee,
        func.row_number().over(
            Window.partitionBy("department").orderBy(func.col("salary").desc())
        ).alias("in_dept_salary_rank")
    )
    .filter(func.col("in_dept_salary_rank") == 1)
    .show()
)
+----------+--------+-------------------+
|department|employee|in_dept_salary_rank|
+----------+--------+-------------------+
|        HR|   alice|                  1|
|        IT|   david|                  1|
+----------+--------+-------------------+

[ ]: