【问题标题】:pyspark window function calculation issue with avg methodavg方法的pyspark窗口函数计算问题
【发布时间】:2020-07-22 16:05:05
【问题描述】:

我有一个如下的输入数据框:

partner_id|month_id|value1 |value2
1001      |  01    |10     |20    
1002      |  01    |20     |30    
1003      |  01    |30     |40
1001      |  02    |40     |50    
1002      |  02    |50     |60    
1003      |  02    |60     |70
1001      |  03    |70     |80    
1002      |  03    |80     |90    
1003      |  03    |90     |100

使用下面的代码,我创建了两个使用窗口函数进行平均的新列:

rnum = (Window.partitionBy("partner_id").orderBy("month_id").rangeBetween(Window.unboundedPreceding, 0))
df = df.withColumn("value1_1", F.avg("value1").over(rnum))
df = df.withColumn("value1_2", F.avg("value2").over(rnum))

输出:

partner_id|month_id|value1 |value2|value1_1|value2_2
1001      |  01    |10     |20    |10      |20
1002      |  01    |20     |30    |20      |30
1003      |  01    |30     |40    |30      |40
1001      |  02    |40     |50    |25      |35
1002      |  02    |50     |60    |35      |45
1003      |  02    |60     |70    |45      |55
1001      |  03    |70     |80    |40      |50
1002      |  03    |80     |90    |50      |60
1003      |  03    |90     |100   |60      |70

使用 pyspark Window 函数的 value1 和 value2 列的累积平均值表现良好。 但是,如果我们在下面的输入中错过了一个月的数据,那么下个月的平均计算应该基于月份。而不是正常的平均值。 例如,如果输入如下(缺少 02 月数据)

partner_id|month_id|value1 |value2
1001      |  01    |10     |20    
1002      |  01    |20     |30    
1003      |  01    |30     |40
1001      |  03    |70     |80    
1002      |  03    |80     |90    
1003      |  03    |90     |100

然后第三个月记录的平均计算发生如下:例如:(70 + 10)/2 但是,如果缺少某些月份值,正确的平均方法是什么???

【问题讨论】:

  • 你能展示“错误”的输出吗?

标签: python dataframe pyspark average pyspark-dataframes


【解决方案1】:

如果您使用的是 spark 2.4+。您可以使用序列函数和数组函数。 这个解决方案的灵感来自link

from pyspark.sql import functions as F
from pyspark.sql.window import Window

w = Window().partitionBy("partner_id")

df1 = (
    df.withColumn(
        "month_seq",
        F.sequence(F.min("month_id").over(w), F.max("month_id").over(w), F.lit(1)),
    )
    .groupBy("partner_id")
    .agg(
        F.collect_list("month_id").alias("month_id"),
        F.collect_list("value1").alias("value1"),
        F.collect_list("value2").alias("value2"),
        F.first("month_seq").alias("month_seq"),
    )
    .withColumn("month_seq", F.array_except("month_seq", "month_id"))
    .withColumn("month_id", F.flatten(F.array("month_id", "month_seq")))
    .drop("month_seq")
    .withColumn("zip", F.explode(F.arrays_zip("month_id", "value1", "value2")))
    .select(
        "partner_id",
        "zip.month_id",
        F.when(F.col("zip.value1").isNull(), F.lit(0))
        .otherwise(F.col("zip.value1"))
        .alias("value1"),
        F.when(F.col("zip.value2").isNull(), F.lit(0))
        .otherwise(F.col("zip.value2"))
        .alias("value2"),
    )
    .orderBy("month_id")
)

rnum = (
    Window.partitionBy("partner_id")
    .orderBy("month_id")
    .rangeBetween(Window.unboundedPreceding, 0)
)

df2 = df1.withColumn("value1_1", F.avg("value1").over(rnum)).withColumn(
    "value1_2", F.avg("value2").over(rnum)
)

结果:

df2.show()

# +----------+--------+------+------+------------------+------------------+
# |partner_id|month_id|value1|value2|          value1_1|          value1_2|
# +----------+--------+------+------+------------------+------------------+
# |      1002|       1|    10|    20|              10.0|              20.0|
# |      1002|       2|     0|     0|               5.0|              10.0|
# |      1002|       3|    80|    90|              30.0|36.666666666666664|
# |      1001|       1|    10|    10|              10.0|              10.0|
# |      1001|       2|     0|     0|               5.0|               5.0|
# |      1001|       3|    70|    80|26.666666666666668|              30.0|
# |      1003|       1|    30|    40|              30.0|              40.0|
# |      1003|       2|     0|     0|              15.0|              20.0|
# |      1003|       3|    90|   100|              40.0|46.666666666666664|
# +----------+--------+------+------+------------------+------------------+

【讨论】:

  • 如果您不想从结果集中获得第二个月。您可以删除那些具有 value1 或 value 2 = 0 的行
【解决方案2】:

Spark 不够聪明,无法理解缺少一个月,因为它甚至不知道一个月可能是什么。

如果您希望“缺失”月份包含在平均计算中,则需要生成缺失数据。

只需使用数据框 ["month_id", "defaultValue"] 执行完全外连接,其中 month_id 是 1 到 12 之间的值且 defaultValue = 0。


另一种解决方案,不是执行平均值,而是执行值的总和,然后除以月份数。

【讨论】:

  • 谢谢。我们可以创建与其他月份相同数量的虚拟记录吗???
猜你喜欢
  • 1970-01-01
  • 2023-02-23
  • 2018-08-27
  • 1970-01-01
  • 2020-04-03
  • 2019-04-23
  • 1970-01-01
  • 2021-11-07
  • 2022-11-15
相关资源
最近更新 更多