【问题标题】:Calculate rolling sum of an array in PySpark using Window()?使用 Window() 计算 PySpark 中数组的滚动总和?
【发布时间】:2020-02-25 16:54:36
【问题描述】:

我想计算一个给定 unix 时间戳的 ArrayType 列的滚动总和,并将其分组为 2 秒增量。示例输入/输出如下。我认为 Window() 函数会起作用,我对 PySpark 还很陌生,完全迷路了。非常感谢任何输入!

输入:

timestamp     vars 
2             [1,2,1,2]
2             [1,2,1,2]
3             [1,1,1,2]
4             [1,3,4,2]
5             [1,1,1,3]
6             [1,2,3,5]
9             [1,2,3,5]

预期输出:

+---------+-----------------------+
|timestamp|vars                   |
+---------+-----------------------+
|2        |[2.0, 4.0, 2.0, 4.0]   |
|4        |[4.0, 8.0, 7.0, 8.0]   |
|6        |[6.0, 11.0, 11.0, 16.0]|
|10       |[7.0, 13.0, 14.0, 21.0]|
+---------+-----------------------+

谢谢!

编辑:多列可以具有相同的时间戳/它们可能不连续。 vars 的长度也可能 > 3。请寻找一个稍微通用的解决方案。

【问题讨论】:

    标签: apache-spark pyspark apache-spark-sql pyspark-dataframes


    【解决方案1】:

    对于 Spark 2.4+,您可以使用数组函数和高阶函数。此解决方案适用于不同的数组大小(如果每行之间的事件不同)。以下是解释的步骤:

    首先,按 2 秒分组,将vars 收集到一个数组列中:

    df = df.groupBy((ceil(col("timestamp") / 2) * 2).alias("timestamp")) \
           .agg(collect_list(col("vars")).alias("vars"))
    
    df.show()
    
    #+---------+----------------------+
    #|timestamp|vars                  |
    #+---------+----------------------+
    #|6        |[[1, 1, 1], [1, 2, 3]]|
    #|2        |[[1, 1, 1], [1, 2, 1]]|
    #|4        |[[1, 1, 1], [1, 3, 4]]|
    #+---------+----------------------+
    

    在这里,我们将每个连续的 2 秒分组,并将 vars 数组收集到一个新列表中。 现在,使用 Window 规范,您可以收集累积值并使用 flatten 来展平子数组:

    w = Window.orderBy("timestamp").rowsBetween(Window.unboundedPreceding, Window.currentRow)
    df = df.withColumn("vars", flatten(collect_list(col("vars")).over(w)))
    df.show()
    
    #+---------+------------------------------------------------------------------+
    #|timestamp|vars                                                              |
    #+---------+------------------------------------------------------------------+
    #|2        |[[1, 1, 1], [1, 2, 1]]                                            |
    #|4        |[[1, 1, 1], [1, 2, 1], [1, 1, 1], [1, 3, 4]]                      |
    #|6        |[[1, 1, 1], [1, 2, 1], [1, 1, 1], [1, 3, 4], [1, 1, 1], [1, 2, 3]]|
    #+---------+------------------------------------------------------------------+
    

    最后,使用aggregate 函数和zip_with 对数组求和:

    t = "aggregate(vars, cast(array() as array<double>), (acc, a) -> zip_with(acc, a, (x, y) -> coalesce(x, 0) + coalesce(y, 0)))"
    
    df.withColumn("vars", expr(t)).show(truncate=False)
    
    #+---------+-----------------+
    #|timestamp|vars             |
    #+---------+-----------------+
    #|2        |[2.0, 3.0, 2.0]  |
    #|4        |[4.0, 7.0, 7.0]  |
    #|6        |[6.0, 10.0, 11.0]|
    #+---------+-----------------+
    

    综合起来:

    from pyspark.sql.functions import ceil, col, collect_list, flatten, expr
    from pyspark.sql import Window
    
    w = Window.orderBy("timestamp").rowsBetween(Window.unboundedPreceding, Window.currentRow)
    t = "aggregate(vars, cast(array() as array<double>), (acc, a) -> zip_with(acc, a, (x, y) -> coalesce(x, 0) + coalesce(y, 0)))"
    
    nb_seconds = 2
    
    df.groupBy((ceil(col("timestamp") / nb_seconds) * nb_seconds).alias("timestamp")) \
      .agg(collect_list(col("vars")).alias("vars")) \
      .withColumn("vars", flatten(collect_list(col("vars")).over(w))) \
      .withColumn("vars", expr(t)).show(truncate=False)
    

    【讨论】:

    • 在表达式中使用带有 zip 的聚合是一种将所有内容组合在一起的好方法。不错的答案
    • 谢谢!你能帮我理解如何使列表更通用 > 3 或不是每 2 秒。如果我想每 5 秒甚至 1 小时进行一次分组怎么办?
    • @justneedsomehelppls 它也适用于列表 > 3,您可以设置要分组的秒数,如代码 nb_seconds = 2 中所示。
    • 收集到的 number 个带有collect_list 的数组可能会变得太大。
    • 这非常有用。谢谢!我唯一担心的是每个单独的列表最长可达 100 万,所以这可能会变得很昂贵。不确定此解决方案的可扩展性。
    【解决方案2】:

    使用sum 窗口函数计算运行总和,使用row_number 选择每隔一个时间戳行。

    from pyspark.sql import Window
    w = Window.orderBy(col('timestamp'))
    result = df.withColumn('summed_vars',array([sum(col('vars')[i]).over(w) for i in range(3)])) #change the value 3 as desired
    result.filter(col('rnum')%2 == 0).select('timestamp','summed_vars').show()
    

    根据您的时间间隔,根据需要更改%2

    编辑:使用window 按时间间隔分组。假设timestamp 列的数据类型为timestamp

    from pyspark.sql import Window
    from pyspark.sql.functions import window,sum,row_number,array,col 
    w = Window.orderBy(col('timestamp'))
    result = df.withColumn('timestamp_interval',window(col('timestamp'),'2 second')) \
               .withColumn('summed_vars',array(*[sum(col('vars')[i]).over(w) for i in range(4)])) 
    w1 = Window.partitionBy(col('timestamp_interval')).orderBy(col('timestamp').desc())
    final_result = result.withColumn('rnum',row_number().over(w1))
    final_result.filter(col('rnum')==1).drop(*['rnum','vars']).show()
    

    【讨论】:

    • 如果有多个列具有相同的时间戳或时间戳不是唯一的怎么办?这也需要灵活地处理不同长度的输入(不仅仅是长度为 3 的数组)和不同的时间间隔(如 1 小时)。你能帮我把这个更通用吗?
    • 您应该将所有这些详细信息添加到问题中,而不是在得到答案之后。
    • 已更新。谢谢!
    • 当多行具有相同的时间戳时,期望的行为是什么?你能举个例子吗?
    • 如果 timestamp 的值以 2 而不是 1 开头,则按 col('rnum')%2 == 0 过滤将不起作用。
    猜你喜欢
    • 1970-01-01
    • 2018-03-27
    • 2020-06-09
    • 1970-01-01
    • 2016-08-12
    • 1970-01-01
    • 1970-01-01
    • 2023-03-09
    • 2022-11-14
    相关资源
    最近更新 更多