【问题标题】:How to aggregate map columns after groupBy?如何在 groupBy 之后聚合地图列?
【发布时间】:2021-03-31 01:05:27
【问题描述】:

我需要合并两个数据框并按键组合列。这两个datafrmaes具有相同的架构,例如:

root
|-- id: String (nullable = true)
|-- cMap: map (nullable = true)
|    |-- key: string
|    |-- value: string (valueContainsNull = true)

我想按“id”分组并将“cMap”聚合在一起以进行重复数据删除。 我试过代码:

val df = df_a.unionAll(df_b).groupBy("id").agg(collect_list("cMap") as "cMap").
rdd.map(x => {
    var map = Map[String,String]()
    x.getAs[Seq[Map[String,String]]]("cMap").foreach( y => 
        y.foreach( tuple =>
        {
            val key = tuple._1
            val value = tuple._2
            if(!map.contains(key))//deduplicate
                map += (key -> value)
        }))

    Row(x.getAs[String]("id"),map)
    })

但似乎 collect_list 不能用于映射结构:

org.apache.spark.sql.AnalysisException: No handler for Hive udf class org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectList because: Only primitive type arguments are accepted but map<string,string> was passed as parameter 1..;

这个问题还有其他解决方案吗?

【问题讨论】:

  • 你能升级到 2.x 吗? 2.x 中的聚合函数不需要 Hive
  • 您的错误是您使用的是org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectList,但您必须使用import org.apache.spark.sql.functions.collect_list,然后它应该可以工作

标签: scala apache-spark apache-spark-sql


【解决方案1】:

您必须首先在映射列上使用explode 函数以解构 映射到键和值列,union 结果数据集,然后distinct 进行重复数据删除,然后groupBy 使用一些自定义 Scala 编码来聚合地图。

别说了,我们来写代码吧……

给定数据集:

scala> a.show(false)
+---+-----------------------+
|id |cMap                   |
+---+-----------------------+
|one|Map(1 -> one, 2 -> two)|
+---+-----------------------+

scala> a.printSchema
root
 |-- id: string (nullable = true)
 |-- cMap: map (nullable = true)
 |    |-- key: string
 |    |-- value: string (valueContainsNull = true)

scala> b.show(false)
+---+-------------+
|id |cMap         |
+---+-------------+
|one|Map(1 -> one)|
+---+-------------+

scala> b.printSchema
root
 |-- id: string (nullable = true)
 |-- cMap: map (nullable = true)
 |    |-- key: string
 |    |-- value: string (valueContainsNull = true)

您应该首先在地图列上使用explode 函数。

explode(e: Column): Column 为给定数组或映射列中的每个元素创建一个新行。

val a_keyValues = a.select('*, explode($"cMap"))
scala> a_keyValues.show(false)
+---+-----------------------+---+-----+
|id |cMap                   |key|value|
+---+-----------------------+---+-----+
|one|Map(1 -> one, 2 -> two)|1  |one  |
|one|Map(1 -> one, 2 -> two)|2  |two  |
+---+-----------------------+---+-----+

val b_keyValues = b.select('*, explode($"cMap"))

使用以下内容,您可以获得不同的键值对,这正是您要求的重复数据删除。

val distinctKeyValues = a_keyValues.
  union(b_keyValues).
  select("id", "key", "value").
  distinct // <-- deduplicate
scala> distinctKeyValues.show(false)
+---+---+-----+
|id |key|value|
+---+---+-----+
|one|1  |one  |
|one|2  |two  |
+---+---+-----+

groupBy 的时间并创建最终的地图列。

val result = distinctKeyValues.
  withColumn("map", map($"key", $"value")).
  groupBy("id").
  agg(collect_list("map")).
  as[(String, Seq[Map[String, String]])]. // <-- leave Rows for typed pairs
  map { case (id, list) => (id, list.reduce(_ ++ _)) }. // <-- collect all entries under one map
  toDF("id", "cMap") // <-- give the columns their names
scala> result.show(truncate = false)
+---+-----------------------+
|id |cMap                   |
+---+-----------------------+
|one|Map(1 -> one, 2 -> two)|
+---+-----------------------+

请注意,从 Spark 2.0.0 开始,unionAll 已被弃用,union 是正确的联合运算符:

(从 2.0.0 版开始)使用 union()

【讨论】:

  • 感谢您的解决方案。我发现应用 select('*,explode($"xxx")) 后,记录的数量减少了很多,我猜是因为 null 或空的 Map 值。但我不确定。由于我在实际项目中还有其他列,因此我将这些列分成字符串类型的列和映射类型的列,分别聚合并连接在一起。顺便说一句,'* 的语法是什么。谢谢。
  • "顺便问一下,'*"的语法是什么 $"*"的另一个版本@
  • 这太棒了。 Spark 开发人员应该努力实现具有几个不同签名的collect_map() sql 函数。
  • @JacekLaskowski 这仅适用于特定的火花版本吗?因为我遵循了您的方法,但遇到了异常 - stackoverflow.com/questions/60750717/…
【解决方案2】:

从 Spark 3.0 开始,您可以:

  • 使用map_entries 将您的地图转换为地图条目数组
  • 使用collect_set按您的ID收集这些数组
  • 使用flatten 展平收集的数组数组
  • 然后使用map_from_entries从扁平数组重建地图

查看以下代码 sn-p 其中input 是您的输入数据框:

import org.apache.spark.sql.functions.{col, collect_set, flatten, map_entries, map_from_entries}

input
  .withColumn("cMap", map_entries(col("cMap")))
  .groupBy("id")
  .agg(map_from_entries(flatten(collect_set("cMap"))).as("cMap"))

示例

给定以下数据框输入:

+---+--------------------+
|id |cMap                |
+---+--------------------+
|1  |[k1 -> v1]          |
|1  |[k2 -> v2, k3 -> v3]|
|2  |[k4 -> v4]          |
|2  |[]                  |
|3  |[k6 -> v6, k7 -> v7]|
+---+--------------------+

上面的代码sn-p返回如下dataframe:

+---+------------------------------+
|id |cMap                          |
+---+------------------------------+
|1  |[k1 -> v1, k2 -> v2, k3 -> v3]|
|3  |[k6 -> v6, k7 -> v7]          |
|2  |[k4 -> v4]                    |
+---+------------------------------+

【讨论】:

    【解决方案3】:

    我同意@Shankar。您的代码似乎完美无缺。

    我认为您正在做的唯一错误是您导入了错误的库。

    你必须导入

    import org.apache.spark.sql.functions.collect_list
    

    但我猜你正在导入

    import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectList
    

    希望我猜对了。

    【讨论】:

      猜你喜欢
      • 2020-04-29
      • 2020-06-14
      • 2016-03-16
      • 2016-02-09
      • 2023-01-23
      • 2021-05-07
      • 2017-07-04
      • 2017-06-05
      • 2020-01-07
      相关资源
      最近更新 更多