【问题标题】:Incorrect ArrayType elements inside Pyspark pandas_udfPyspark pandas_udf 中的 ArrayType 元素不正确
【发布时间】:2018-12-28 13:49:00
【问题描述】:

我正在使用 Spark 2.3.0 并在我的 Pyspark 代码中尝试使用 pandas_udf 用户定义的函数。根据https://github.com/apache/spark/pull/20114,目前支持ArrayType。我的用户定义函数是:

def transform(c):
    if not any(isinstance(x, (list, tuple, np.ndarray)) for x in c.values):
        nvalues = c.values
    else:
        nvalues = np.array(c.values.tolist())
    tvalues = some_external_function(nvalues)
    if not any(isinstance(y, (list, tuple, np.ndarray)) for y in tvalues):
        p = pd.Series(np.array(tvalues))
    else:
        p = pd.Series(list(tvalues))
    return p

transform = pandas_udf(transform, ArrayType(LongType()))

当我将此函数应用于大型 Spark Dataframe 的特定数组列时,我注意到 pandas 系列 c 的第一个元素与其他元素相比具有不同的双倍大小,而最后一个元素的大小为 0:

0       [73, 10, 223, 46, 14, 73, 14, 5, 14, 21, 10, 2...
1                [223, 46, 14, 73, 14, 5, 14, 21, 30, 16]
2                 [46, 14, 73, 14, 5, 14, 21, 30, 16, 15]
...
4695                                                   []
Name: _70, Length: 4696, dtype: object

第一个数组有 20 个元素,第二个有 10 个元素(这是正确的一个),最后一个是 0。当然,c.values 会失败,ValueError: setting an array element with a sequence.,因为数组有多种大小。

当我尝试使用相同的函数对字符串数组进行列时,所有大小都是正确的,其余的函数也是如此。

知道可能是什么问题吗?可能的错误?

更新示例:

我附上一个简单的例子,只是打印 pandas_udf 函数中的值。

from pyspark.sql.types import *
from pyspark.sql.functions import *
from pyspark.sql import SparkSession

if __name__ == "__main__":
    spark = SparkSession\
        .builder\
        .appName("testing pandas_udf")\
        .getOrCreate()

    arr = []
    for i in range(100000):
        arr.append([2,2,2,2,2])

    df = spark.createDataFrame(arr, ArrayType(LongType()))

    def transform(c):
        print(c)
        print(c.values)
        return c

    transform = pandas_udf(transform, ArrayType(LongType()))

    df = df.withColumn('new_value', transform(col('value')))
    df.show()

    spark.stop()

如果你检查执行者的日志,你会看到类似的日志:

0       [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
1                                     [2, 2, 2, 2, 2]
2                                     [2, 2, 2, 2, 2]
3                                     [2, 2, 2, 2, 2]
4                                     [2, 2, 2, 2, 2]
5                                     [2, 2, 2, 2, 2]
...
9996                                  [2, 2, 2, 2, 2]
9997                                  [2, 2, 2, 2, 2]
9998                                               []
9999                                               []
Name: _0, Length: 10000, dtype: object

已解决:

如果您遇到同样的问题,请升级到 Spark 2.3.1 和 pyarrow 0.9.0.post1。

【问题讨论】:

  • 谢谢。我无法重现(Spark 2.3.1,pyarrow 0.9.0)。如果您使用早期版本,您可以更新并检查问题是否仍然存在。
  • 我有 Spark 2.3.0 和 pyarrow 0.8.0.. 但是使用 Spark 2.3.1,pyarrow 0.9.0.post1 可以完美运行!非常感谢!
  • 别提了。我猜你是can answer your own question now - 它可能对未来的用户有所帮助。

标签: apache-spark pyspark apache-spark-sql user-defined-functions


【解决方案1】:

是的,看起来 Spark 中存在错误。我的情况涉及 2.3.0 和 PyArrow 0.13.0。我唯一可用的补救方法是将数组转换为字符串,然后将其传递给 Pandas UDF。

def _identity(sample_array):
    return sample_array.apply(lambda e: [int(i) for i in e.split(',')])

array_identity_udf = F.pandas_udf(_identity,
                                  returnType=ArrayType(IntegerType()),
                                  functionType=F.PandasUDFType.SCALAR)
test_df = (spark
           .table('test_table')
           .repartition(F.ceil(F.rand(seed=1234) * 100))
           .cache())

test1_df = (test_df
           .withColumn('array_test', array_identity_udf(F.concat_ws(',', F.col('sample_array')))))

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2019-03-29
    • 2023-03-04
    • 2020-01-08
    • 1970-01-01
    • 2023-03-25
    • 2021-08-31
    • 1970-01-01
    相关资源
    最近更新 更多