【发布时间】:2021-05-10 14:01:08
【问题描述】:
问题陈述
考虑以下数据(见底部的代码生成)
+-----+-----+-------+--------+
|index|group|low_num|high_num|
+-----+-----+-------+--------+
| 0| 1| 1| 1|
| 1| 1| 2| 2|
| 2| 1| 3| 3|
| 3| 2| 1| 3|
+-----+-----+-------+--------+
然后对于给定的索引,我想计算group 中所有low_num 中的一个索引high_num 大于low_num 的次数。
例如,考虑带有index 的第二行:1。 Index:1 在 group:1 中,high_num 是 2。索引1 上的high_num 大于索引0 上的high_num,等于low_num,并且小于索引2 上的high_num。所以index: 1 的high_num 在整个组中大于low_num 一次,所以我希望答案列中的值是1。
具有所需输出的数据集
+-----+-----+-------+--------+-------+
|index|group|low_num|high_num|desired|
+-----+-----+-------+--------+-------+
| 0| 1| 1| 1| 0|
| 1| 1| 2| 2| 1|
| 2| 1| 3| 3| 2|
| 3| 2| 1| 3| 1|
+-----+-----+-------+--------+-------+
数据集生成代码
from pyspark.sql import SparkSession
spark = (
SparkSession
.builder
.getOrCreate()
)
## Example df
## Note the inclusion of "desired" which is the desired output.
df = spark.createDataFrame(
[
(0, 1, 1, 1, 0),
(1, 1, 2, 2, 1),
(2, 1, 3, 3, 2),
(3, 2, 1, 3, 1)
],
schema=["index", "group", "low_num", "high_num", "desired"]
)
可能已经解决问题的伪代码
伪代码可能如下所示:
import pyspark.sql.functions as F
from pyspark.sql.window import Window
w_spec = Window.partitionBy("group").rowsBetween(
Window.unboundedPreceding, Window.unboundedFollowing)
## F.collect_list_when does not exist
## F.current_col does not exist
## Probably wouldn't work like this anyways
ddf = df.withColumn("Counts",
F.size(F.collect_list_when(
F.current_col("high_number") > F.col("low_number"), 1
).otherwise(None).over(w_spec))
)
【问题讨论】:
标签: pyspark apache-spark-sql aggregate