【发布时间】:2017-07-28 14:25:28
【问题描述】:
我有一个包含列的数据集:id,timestamp,x,y
id timestamp x y
0 1443489380 100 1
0 1443489390 200 0
0 1443489400 300 0
0 1443489410 400 1
我定义了一个窗口规范:w = Window.partitionBy("id").orderBy("timestamp")
我想做这样的事情。创建一个新列,将当前行的 x 与下一行的 x 相加。
如果总和 >= 500 则设置新列 = BIG 否则 SMALL。
df = df.withColumn("newCol",
when(df.x + lag(df.x,-1).over(w) >= 500 , "BIG")
.otherwise("SMALL") )
但是,我想在此之前过滤数据而不影响原始df。
[只有 y =1 的行才会应用上面的代码]
所以将应用上述代码的数据只有这 2 行。
0 , 1443489380, 100 , 1
0 , 1443489410, 400 , 1
我已经这样做了,但是太糟糕了。
df2 = df.filter(df.y == 1)
df2 = df2.withColumn("newCol",
when(df.x + lag(df.x,-1).over(w) >= 500 , "BIG")
.otherwise("SMALL") )
df = df.join(df2, ["id","timestamp"], "outer")
我想做这样的事情,但这是不可能的,因为它会导致 AttributeError: 'DataFrame' object has no attribute 'when'
df = df.withColumn("newCol", df.filter(df.y == 1)
.when(df.x + lag(df.x,-1).over(w) >= 500 , "BIG")
.otherwise("SMALL") )
总之,我只想在 sum x 和下一个 x 之前只对 y =1 的行做一个临时过滤。
【问题讨论】:
-
你已经从 pyspark.sql.functions 导入了,对吧?
-
我已经通过 from pyspark.sql.functions import lag 导入,当