【问题标题】:Filter array column content过滤数组列内容
【发布时间】:2018-11-08 02:27:57
【问题描述】:

我正在使用 pyspark 2.3.1 并希望使用表达式而不是使用 udf 过滤数组元素:

>>> df = spark.createDataFrame([(1, "A", [1,2,3,4]), (2, "B", [1,2,3,4,5])],["col1", "col2", "col3"])
>>> df.show()
+----+----+---------------+
|col1|col2|           col3|
+----+----+---------------+
|   1|   A|   [1, 2, 3, 4]|
|   2|   B|[1, 2, 3, 4, 5]|
+----+----+---------------+

下面显示的表达式是错误的,我想知道如何告诉 spark 从 col3 中的数组中删除小于 3 的任何值。我想要类似的东西:

>>> filtered = df.withColumn("newcol", expr("filter(col3, x -> x >= 3)")).show()
>>> filtered.show()
+----+----+---------+
|col1|col2|   newcol|
+----+----+---------+
|   1|   A|   [3, 4]|
|   2|   B|[3, 4, 5]|
+----+----+---------+

我已经有一个udf解决方案,但是很慢(> 10亿数据行):

largerThan = F.udf(lambda row,max: [x for x in row if x >= max], ArrayType(IntegerType()))
df = df.withColumn('newcol', size(largerThan(df.queries, lit(3))))

欢迎任何帮助。非常感谢您。

【问题讨论】:

  • 您不能遍历数组。您可以explodefiltercollect_list 来避免udf,但这也是一项昂贵的操作。也可以序列化为rdd。见related

标签: apache-spark pyspark pyspark-sql


【解决方案1】:

火花

在 PySpark 中没有*合理的替换 udf

火花 >= 2.4

您的代码:

expr("filter(col3, x -> x >= 3)")

可以按原样使用。

参考

Querying Spark SQL DataFrame with complex types


* 考虑到爆炸或转换到 RDD 的成本,udf 几乎完全是可取的。

【讨论】:

    猜你喜欢
    • 2018-12-23
    • 2018-05-15
    • 2017-11-03
    • 1970-01-01
    • 1970-01-01
    • 2021-12-04
    • 2011-01-20
    • 2019-10-01
    • 1970-01-01
    相关资源
    最近更新 更多