【问题标题】:filter a list in pyspark dataframe过滤 pyspark 数据框中的列表
【发布时间】:2020-08-19 21:27:35
【问题描述】:

我在 pyspark (v2.4.5) 数据框中有一个句子列表,其中包含一组匹配的分数。句子和分数采用列表形式。

df=spark.createDataFrame(
    [
        (1, ['foo1','foo2','foo3'],[0.1,0.5,0.6]), # create your data here, be consistent in the types.
        (2, ['bar1','bar2','bar3'],[0.5,0.7,0.7]),
        (3, ['baz1','baz2','baz3'],[0.1,0.2,0.3]),
    ],
    ['id', 'txt','score'] # add your columns label here
)
df.show()
+---+------------------+---------------+
| id|               txt|          score|
+---+------------------+---------------+
|  1|[foo1, foo2, foo3]|[0.1, 0.5, 0.6]|
|  2|[bar1, bar2, bar3]|[0.5, 0.7, 0.7]|
|  3|[baz1, baz2, baz3]|[0.1, 0.2, 0.3]|
+---+------------------+---------------+

我想过滤并只返回那些得分 >=0.5 的句子。

+---+------------------+---------------+
| id|               txt|          score|
+---+------------------+---------------+
|  1|      [foo2, foo3]|     [0.5, 0.6]|
|  2|[bar1, bar2, bar3]|[0.5, 0.7, 0.7]|
+---+------------------+---------------+

有什么建议吗?

我尝试了pyspark dataframe filter or include based on list,但无法让它在我的实例中运行

【问题讨论】:

    标签: list filter pyspark


    【解决方案1】:

    使用 spark 2.4+ ,您可以访问高阶函数,因此您可以过滤带有条件的压缩数组,然后过滤掉空白数组:

    import pyspark.sql.functions as F
    
    e = F.expr('filter(arrays_zip(txt,score),x-> x.score>=0.5)')
    df.withColumn("txt",e.txt).withColumn("score",e.score).filter(F.size(e)>0).show()
    

    +---+------------------+---------------+
    | id|               txt|          score|
    +---+------------------+---------------+
    |  1|      [foo2, foo3]|     [0.5, 0.6]|
    |  2|[bar1, bar2, bar3]|[0.5, 0.7, 0.7]|
    +---+------------------+---------------+
    

    【讨论】:

    • ^这应该是 spark2.4+ 的接受答案
    【解决方案2】:

    试试这个,我想不出没有 UDF 的方法:

    from pyspark.sql.types import ArrayType, BooleanType, StringType()
    
    # UDF for boolean index
    filter_udf = udf(lambda arr: [True if x >= 0.5 else False for x in arr], ArrayType(BooleanType()))
    
    # UDF for filtering on the boolean index
    filter_udf_bool = udf(lambda col_arr, bool_arr: [x for (x,y) in zip(col_arr,bool_arr) if y], ArrayType(StringType()))
    
    df2 = df.withColumn("test", filter_udf("score"))
    df3 = df2.withColumn("txt", filter_udf_bool("txt", "test")).withColumn("score", filter_udf_bool("score", "test"))
    

    输出:

    # Further filtering for empty arrays:
    df3.drop("test").filter(F.size(F.col("txt")) > 0).show()
    
    +---+------------------+---------------+
    | id|               txt|          score|
    +---+------------------+---------------+
    |  1|      [foo2, foo3]|     [0.5, 0.6]|
    |  2|[bar1, bar2, bar3]|[0.5, 0.7, 0.7]|
    +---+------------------+---------------+
    

    实际上,您也可以通过将 UDF 合并为一个来概括 UDF。为了简单起见,我将其拆分。

    【讨论】:

      【解决方案3】:

      在 spark 中,用户定义的函数被视为黑盒,因为催化剂优化器无法优化 udf 内的代码。所以尽可能避免使用 udf。

      这是一个不使用 UDF 的示例

      df.withColumn('combined',f.explode(f.arrays_zip('txt','score'))).filter(f.col('combined.score')>=0.5).groupby('id').agg(f.collect_list('combined.txt').alias('txt'),f.collect_list('combined.score').alias('score')).show()
      
      +---+------------------+---------------+
      | id|               txt|          score|
      +---+------------------+---------------+
      |  1|      [foo2, foo3]|     [0.5, 0.6]|
      |  2|[bar1, bar2, bar3]|[0.5, 0.7, 0.7]|
      +---+------------------+---------------+
      

      希望它有效。

      【讨论】:

      • 我正在尝试使用arrays_zip,但我不知道如何过滤,您已经使用combined.score 完成了它。这是一个很好的答案。
      • 使用explode + collect_list is expensive - udf 是首选,但我认为有一种方法可以使用高阶函数。
      【解决方案4】:

      score是一种数组,需要用谓词进一步过滤。

      代码 sn-p 过滤数组列:

      def score_filter(row):
          score_filtered = [s for s in row.score if s >= 0.5]
          if len(score_filtered) > 0:
              yield row
      
      
      filtered = df.rdd.flatMap(score_filter).toDF()
      
      filtered.show()
      

      输出:

      +---+------------------+---------------+
      | id|               txt|          score|
      +---+------------------+---------------+
      |  1|[foo1, foo2, foo3]|[0.1, 0.5, 0.6]|
      |  2|[bar1, bar2, bar3]|[0.5, 0.7, 0.7]|
      +---+------------------+---------------+
      

      【讨论】:

      • 感谢您的建议,不幸的是,这在第一行中引入了 0.1 的值('foo1' - 应该被过滤掉
      • @alex 是的,但是去掉 0.1 和 txt 列 foo1 更多的是根据条件过滤掉过滤方法中的数据的逻辑。
      猜你喜欢
      • 2020-09-22
      • 1970-01-01
      • 2019-03-07
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2017-06-28
      • 1970-01-01
      • 2021-06-27
      相关资源
      最近更新 更多