【Question Title】: How to compare the value of one row with all the other rows in PySpark on grouped values
【Posted】: 2021-05-10 14:01:08
【Question】:

Problem statement

Consider the following data (see the generation code at the bottom):

+-----+-----+-------+--------+
|index|group|low_num|high_num|
+-----+-----+-------+--------+
|    0|    1|      1|       1|
|    1|    1|      2|       2|
|    2|    1|      3|       3|
|    3|    2|      1|       3|
+-----+-----+-------+--------+

Then, for a given index, I want to count how many of the low_num values in its group that index's high_num is greater than.

For example, consider the second row, with index 1. Index 1 is in group 1, and its high_num is 2. That high_num is greater than the low_num at index 0, equal to the low_num at index 1, and less than the low_num at index 2. So the high_num for index 1 is greater than a low_num once in the whole group, and I would expect the value in the answer column to be 1.

Dataset with the desired output

+-----+-----+-------+--------+-------+
|index|group|low_num|high_num|desired|
+-----+-----+-------+--------+-------+
|    0|    1|      1|       1|      0|
|    1|    1|      2|       2|      1|
|    2|    1|      3|       3|      2|
|    3|    2|      1|       3|      1|
+-----+-----+-------+--------+-------+

Dataset generation code

from pyspark.sql import SparkSession
spark = (
    SparkSession
    .builder
    .getOrCreate()
)
## Example df
## Note the inclusion of "desired" which is the desired output.
df = spark.createDataFrame(
    [
        (0, 1, 1, 1, 0),
        (1, 1, 2, 2, 1),
        (2, 1, 3, 3, 2),
        (3, 2, 1, 3, 1)
    ],
    schema=["index", "group", "low_num", "high_num", "desired"]
)

Pseudocode that might solve the problem

The pseudocode could look something like this:

import pyspark.sql.functions as F
from pyspark.sql.window import Window

w_spec = Window.partitionBy("group").rowsBetween(
    Window.unboundedPreceding, Window.unboundedFollowing)

## F.collect_list_when does not exist
## F.current_col does not exist
## Probably wouldn't work like this anyways
ddf = df.withColumn("Counts",
                    F.size(F.collect_list_when(
                             F.current_col("high_num") > F.col("low_num"), 1
                          ).otherwise(None).over(w_spec))
                   )

【Discussion】:

Tags: pyspark apache-spark-sql aggregate


【Solution 1】:

You can do a filter on the collect_list and check its size:

    import pyspark.sql.functions as F
    
    df2 = df.withColumn(
        'desired', 
        F.expr('size(filter(collect_list(low_num) over (partition by group), x -> x < high_num))')
    )
    
    df2.show()
    +-----+-----+-------+--------+-------+
    |index|group|low_num|high_num|desired|
    +-----+-----+-------+--------+-------+
    |    0|    1|      1|       1|      0|
    |    1|    1|      2|       2|      1|
    |    2|    1|      3|       3|      2|
    |    3|    2|      1|       3|      1|
    +-----+-----+-------+--------+-------+
    

【Discussion】:
