【问题标题】:Element-wise sum of arrays across multiple columns of a data frame in Spark / Scala?Spark / Scala中数据框多列的数组元素总和?
【发布时间】:2026-02-21 10:05:01
【问题描述】:

我有一个数据框,它可以有多个数组类型的列,如“Array1”、“Array2”……等。这些数组列将具有相同数量的元素。我需要计算一个数组类型的新列,它将是数组元素的总和。我该怎么做?

Spark 版本 = 2.3

例如:

输入:

|Column1| ... |ArrayColumn2|ArrayColumn3|
|-------| --- |------------|------------|
|T1     | ... |[1, 2 , 3]  | [2, 5, 7]

输出:

|Column1| ... |AggregatedColumn|
|-------| --- |------------|
|T1.    | ... |[3, 7 , 10]

数组列的数量不固定,因此我需要一个通用的解决方案。我会有一个需要汇总的列列表。

谢谢!

【问题讨论】:

  • 所有数组列是否具有相同数量的元素?
  • 是的,数组列将具有相同数量的元素。我会更新澄清这一点的问题。

标签: scala apache-spark


【解决方案1】:

考虑使用inline 和高阶函数aggregate(在Spark 2.4+ 中可用)从数组类型的列中计算元素总和,然后使用groupBy/agg 将元素总和分组返回进入数组:

val df = Seq(
  (101, Seq(1, 2), Seq(3, 4), Seq(5, 6)),
  (202, Seq(7, 8), Seq(9, 10), Seq(11, 12))
).toDF("id", "arr1", "arr2", "arr3")

val arrCols = df.columns.filter(_.startsWith("arr")).map(col)

适用于 Spark 3.0+

df.
  withColumn("arr_structs", arrays_zip(arrCols: _*)).
  select($"id", expr("inline(arr_structs)")).
  select($"id", aggregate(array(arrCols: _*), lit(0), (acc, x) => acc + x).as("pos_elem_sum")).
  groupBy("id").agg(collect_list($"pos_elem_sum").as("arr_elem_sum")).
  show
// +---+------------+
// | id|arr_elem_sum|
// +---+------------+
// |101|     [9, 12]|
// |202|    [27, 30]|
// +---+------------+

适用于 Spark 2.4+

df.
  withColumn("arr_structs", arrays_zip(arrCols: _*)).
  select($"id", expr("inline(arr_structs)")).
  select($"id", array(arrCols: _*).as("arr_pos_elems")).
  select($"id", expr("aggregate(arr_pos_elems, 0, (acc, x) -> acc + x)").as("pos_elem_sum")).
  groupBy("id").agg(collect_list($"pos_elem_sum").as("arr_elem_sum")).
  show

适用于 Spark 2.3 或更低版本

val sumArrElems = udf{ (arr: Seq[Int]) => arr.sum }

df.
  withColumn("arr_structs", arrays_zip(arrCols: _*)).
  select($"id", expr("inline(arr_structs)")).
  select($"id", sumArrElems(array(arrCols: _*)).as("pos_elem_sum")).
  groupBy("id").agg(collect_list($"pos_elem_sum").as("arr_elem_sum")).
  show

【讨论】:

  • @Diamondhead,刚刚注意到您已指定使用 Spark 2.3。答案扩展到涵盖各种 Spark 版本。
【解决方案2】:

array(ArrayColumn2[0]+ArrayColumn3[0], ArrayColumn2[1]+...) 这样的 SQL 表达式可用于计算预期结果。

val df = ...

//get all array columns
val arrayCols = df.schema.fields.filter(_.dataType.isInstanceOf[ArrayType]).map(_.name)

//get the size of the first array of the first row
val firstArray = arrayCols(0)
val arraySize = df.selectExpr(s"size($firstArray)").first().getAs[Int](0)

//generate the sql expression for the sums
val sums = (for( i <-0 to arraySize-1)
  yield arrayCols.map(c=>s"$c[$i]").mkString("+")).mkString(",")
//sums = ArrayColumn2[0]+ArrayColumn3[0],ArrayColumn2[1]+ArrayColumn3[1],ArrayColumn2[2]+ArrayColumn3[2]

//create a new column using sums
df.withColumn("AggregatedColumn", expr(s"array($sums)")).show()

输出:

+-------+------------+------------+----------------+
|Column1|ArrayColumn2|ArrayColumn3|AggregatedColumn|
+-------+------------+------------+----------------+
|     T1|   [1, 2, 3]|   [2, 5, 7]|      [3, 7, 10]|
+-------+------------+------------+----------------+

使用这个单一(长)SQL 表达式将避免在网络上打乱数据,从而提高性能。

【讨论】: