【问题标题】:PySpark DF column creation with UDF to mimic np.roll function from numpy使用 UDF 创建 PySpark DF 列以模仿 numpy 中的 np.roll 函数
【发布时间】:2023-03-25 09:48:01
【问题描述】:

尝试在 PySpark UDF 中创建新列,但值为 null!

创建 DF

data_list = [['a', [1, 2, 3]], ['b', [4, 5, 6]],['c', [2, 4, 6, 8]],['d', [4, 1]],['e', [1,2]]]
all_cols = ['COL1','COL2']
df = sqlContext.createDataFrame(data_list, all_cols)
df.show()
+----+------------+
|COL1|        COL2|
+----+------------+
|   a|   [1, 2, 3]|
|   b|   [4, 5, 6]|
|   c|[2, 4, 6, 8]|
|   d|      [4, 1]|
|   e|      [1, 2]|
+----+------------+

df.printSchema()
root
 |-- COL1: string (nullable = true)
 |-- COL2: array (nullable = true)
 |    |-- element: long (containsNull = true)

创建函数

def cr_pair(idx_src, idx_dest):
    idx_dest.append(idx_dest.pop(0))
    return idx_src, idx_dest
lst1 = [1,2,3]
lst2 = [1,2,3]
cr_pair(lst1, lst2)
([1, 2, 3], [2, 3, 1])

创建和注册 UDF

from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

from pyspark.sql.types import ArrayType
get_idx_pairs = udf(lambda x: cr_pair(x, x), ArrayType(IntegerType()))

向 DF 添加新列

df = df.select('COL1', 'COL2',  get_idx_pairs('COL2').alias('COL3'))
df.printSchema()
root
 |-- COL1: string (nullable = true)
 |-- COL2: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- COL3: array (nullable = true)
 |    |-- element: integer (containsNull = true)

df.show()
+----+------------+------------+
|COL1|        COL2|        COL3|
+----+------------+------------+
|   a|   [1, 2, 3]|[null, null]|
|   b|   [4, 5, 6]|[null, null]|
|   c|[2, 4, 6, 8]|[null, null]|
|   d|      [4, 1]|[null, null]|
|   e|      [1, 2]|[null, null]|
+----+------------+------------+

问题出在哪里。 我在 COL3 列中得到所有值“null”。 预期的结果应该是:

+----+------------+----------------------------+
|COL1|        COL2|                        COL3|
+----+------------+----------------------------+
|   a|   [1, 2, 3]|[[1 ,2, 3], [2, 3, 1]]      |
|   b|   [4, 5, 6]|[[4, 5, 6], [5, 6, 4]]      |
|   c|[2, 4, 6, 8]|[[2, 4, 6, 8], [4, 6, 8, 2]]|
|   d|      [4, 1]|[[4, 1], [1, 4]]            |
|   e|      [1, 2]|[[1, 2], [2, 1]]            |
+----+------------+----------------------------+

【问题讨论】:

    标签: apache-spark pyspark apache-spark-sql user-defined-functions


    【解决方案1】:

    您的 UDF 应该返回 ArrayType(ArrayType(IntegerType())),因为您希望列中包含列表列表,此外它只需要一个参数:

    def cr_pair(idx_src):
        return idx_src, idx_src[1:] + idx_src[:1]
    
    get_idx_pairs = udf(cr_pair, ArrayType(ArrayType(IntegerType())))
    df.withColumn('COL3', get_idx_pairs(df['COL2'])).show(5, False)
    +----+------------+----------------------------+
    |COL1|COL2        |COL3                        |
    +----+------------+----------------------------+
    |a   |[1, 2, 3]   |[[2, 3, 1], [2, 3, 1]]      |
    |b   |[4, 5, 6]   |[[5, 6, 4], [5, 6, 4]]      |
    |c   |[2, 4, 6, 8]|[[4, 6, 8, 2], [4, 6, 8, 2]]|
    |d   |[4, 1]      |[[1, 4], [1, 4]]            |
    |e   |[1, 2]      |[[2, 1], [2, 1]]            |
    +----+------------+----------------------------+
    

    【讨论】:

    • 哇,真快。谢谢你Psidom!差不多好了。 “表演”的结果有点偏离。见下文......好吧,我无法正确复制结果,但在列表列表前面我得到“[WrappedArray(1,..”
    • 您是指... 吗?这只是打印截断。查看更新。
    • 我看到了您的编辑。如果它没有回答问题,请不要编辑答案。话虽如此,您使用的是什么 python 和 spark 版本?我用 spark 2.3.0 和 python 3.6.5 得到了很好的结果。
    • Python 2.7.13 |Anaconda 4.4.0(64 位)和 PySpark 2.2.0
    • 顺便说一句,我如何从“评论”中输入结果或代码?我已经尝试过了,它总是关闭“代码格式”。这就是我编辑你的回复的原因!!
    【解决方案2】:

    看起来您想要做的是循环移动列表中的元素。这是使用pyspark.sql.functions.posexplode()(Spark 版本 2.1 及更高版本)的非 udf 方法:

    import pyspark.sql.functions as f
    from pyspark.sql import Window
    
    w = Window.partitionBy("COL1", "COL2").orderBy(f.col("pos") == 0, "pos")
    df = df.select("*", f.posexplode("COL2"))\
        .select("COL1", "COL2", "pos", f.collect_list("col").over(w).alias('COL3'))\
        .where("pos = 0")\
        .drop("pos")\
        .withColumn("COL3", f.array("COL2", "COL3"))
    
    df.show(truncate=False)
    #+----+------------+----------------------------------------------------+
    #|COL1|COL2        |COL3                                                |
    #+----+------------+----------------------------------------------------+
    #|a   |[1, 2, 3]   |[WrappedArray(1, 2, 3), WrappedArray(2, 3, 1)]      |
    #|b   |[4, 5, 6]   |[WrappedArray(4, 5, 6), WrappedArray(5, 6, 4)]      |
    #|c   |[2, 4, 6, 8]|[WrappedArray(2, 4, 6, 8), WrappedArray(4, 6, 8, 2)]|
    #|d   |[4, 1]      |[WrappedArray(4, 1), WrappedArray(1, 4)]            |
    #|e   |[1, 2]      |[WrappedArray(1, 2), WrappedArray(2, 1)]            |
    #+----+------------+----------------------------------------------------+
    

    使用posexplode 将返回两列 - 列表中的位置 (pos) 和值 (col)。这里的诀窍是我们先按f.col("pos") == 0 订购,然后再订购"pos"。这会将数组中的第一个位置移动到列表的末尾。

    虽然此输出 打印 与您在 python 中的列表列表不同,但 COL3 的内容确实是整数列表的列表。

    df.printSchema()
    #root
    # |-- COL1: string (nullable = true)
    # |-- COL2: array (nullable = true)
    # |    |-- element: long (containsNull = true)
    # |-- COL3: array (nullable = false)
    # |    |-- element: array (containsNull = true)
    # |    |    |-- element: long (containsNull = true)
    

    更新

    WrappedArray 前缀”正是 Spark 打印嵌套列表的方式。底层数组正是您需要的。验证这一点的一种方法是调用collect() 并检查数据:

    results = df.collect()
    print([(r["COL1"], r["COL3"]) for r in results])
    #[(u'a', [[1, 2, 3], [2, 3, 1]]),
    # (u'b', [[4, 5, 6], [5, 6, 4]]),
    # (u'c', [[2, 4, 6, 8], [4, 6, 8, 2]]),
    # (u'd', [[4, 1], [1, 4]]),
    # (u'e', [[1, 2], [2, 1]])]
    

    或者,如果您将 df 转换为 pandas DataFrame:

    print(df.toPandas())
    #  COL1          COL2                          COL3
    #0    a     [1, 2, 3]        ([1, 2, 3], [2, 3, 1])
    #1    b     [4, 5, 6]        ([4, 5, 6], [5, 6, 4])
    #2    c  [2, 4, 6, 8]  ([2, 4, 6, 8], [4, 6, 8, 2])
    #3    d        [4, 1]              ([4, 1], [1, 4])
    #4    e        [1, 2]              ([1, 2], [2, 1])
    

    【讨论】:

    • 这是伟大的保尔特,谢谢!但是,我认为我的用例不需要“窗口功能”。我只需要(是)获取原始列表并将其“一个位置移位”作为一对返回。我基本上是在尝试模仿 NumPy 的 np.roll 功能。如果我能在没有“WrappedArray”前缀的情况下获得 COL3 的内容,就可以解决我的问题。再次感谢,非常感谢您的努力!
    • 我想告诉你的是“WrappedArray 前缀”就是它的打印方式。内容完全符合您的需要。我在这里使用Window,因为这种方法是(我能想到的)避免udf(udf 速度较慢)的唯一方法。
    • @TSAR 查看我收集数据的编辑。结果是您所期望的列表列表。
    • @TSAR 阅读 this post 这解释了为什么您通常尽可能避免使用 udfs。
    猜你喜欢
    • 1970-01-01
    • 2018-08-27
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多