【问题标题】:Multiplication of members of two arrays两个数组成员的乘法
【发布时间】:2021-06-28 14:29:38
【问题描述】:

我有下表:

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()

cols = [  'a1',   'a2']
data = [([2, 3], [4, 5]),
        ([1, 3], [2, 4])]

df = spark.createDataFrame(data, cols)
df.show()
#  +------+------+
#  |    a1|    a2|
#  +------+------+
#  |[2, 3]|[4, 5]|
#  |[1, 3]|[2, 4]|
#  +------+------+

我知道how to multiply array by a scalar。但是如何将一个数组的成员与另一个数组的对应成员相乘呢?

想要的结果:

#  +------+------+-------+
#  |    a1|    a2|    res|
#  +------+------+-------+
#  |[2, 3]|[4, 5]|[8, 15]|
#  |[1, 3]|[2, 4]|[2, 12]|
#  +------+------+-------+

【问题讨论】:

    标签: arrays apache-spark pyspark apache-spark-sql multiplication


    【解决方案1】:

    与您的示例类似,您可以从转换函数访问第二个数组。这假设两个数组具有相同的长度:

    from pyspark.sql.functions import expr
    
    cols = [  'a1',   'a2']
    data = [([2, 3], [4, 5]),
            ([1, 3], [2, 4])]
    
    df = spark.createDataFrame(data, cols)
    
    df = df.withColumn("res", expr("transform(a1, (x, i) -> a2[i] * x)"))
    
    # +------+------+-------+
    # |    a1|    a2|    res|
    # +------+------+-------+
    # |[2, 3]|[4, 5]|[8, 15]|
    # |[1, 3]|[2, 4]|[2, 12]|
    # +------+------+-------+
    

    【讨论】:

    • 谢谢,这个版本看起来很流畅。
    【解决方案2】:

    假设您可以拥有不同大小的数组:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import expr
    
    spark = SparkSession.builder.getOrCreate()
    
    cols = ['a1', 'a2']
    data = [([2, 3], [4, 5]),
            ([1, 3], [2, 4]),
            ([1, 3], [2, 4, 6])]
    
    df = spark.createDataFrame(data, cols)
    df = df.withColumn("res", expr("transform(arrays_zip(a1, a2), x -> coalesce(x.a1 * x.a2, 0))"))
    
    df.show(truncate=False)
    # +------+---------+----------+
    # |a1    |a2       |res       |
    # +------+---------+----------+
    # |[2, 3]|[4, 5]   |[8, 15]   |
    # |[1, 3]|[2, 4]   |[2, 12]   |
    # |[1, 3]|[2, 4, 6]|[2, 12, 0]|
    # +------+---------+----------+
    

    【讨论】:

    • 谢谢,想想真是明智之举!我想我将来可能会使用它。
    【解决方案3】:

    使用用户定义函数(UDF)创建一个函数来执行乘法和调用这个函数。

    def sum(x, y):
        return [x[0] * y[0], x[1] * y[1]]
    
    sum_cols = udf(sum, ArrayType(IntegerType()))
    
    df1 = df.withColumn("res", sum_cols('a1', 'a2'))
    
    df1.show()
    
    +------+------+-------+
    |    a1|    a2|    res|
    +------+------+-------+
    |[2, 3]|[4, 5]|[8, 15]|
    |[1, 3]|[2, 4]|[2, 12]|
    +------+------+-------+
    

    https://docs.databricks.com/spark/latest/spark-sql/udf-python.html

    【讨论】:

    • 感谢您的精彩回答。在这种情况下,我决定使用无 udf 的方法。
    猜你喜欢
    • 2015-07-03
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2016-04-14
    • 2022-01-10
    相关资源
    最近更新 更多