【问题标题】:Mapping multiple columns to a single key in a Spark dataframe将多列映射到 Spark 数据框中的单个键
【发布时间】:2019-07-15 16:17:34
【问题描述】:

我有一个如下所示的 Spark 数据框:

+------+-----+-----+
|acctId|vehId|count|
+------+-----+-----+
|     1|  666|    1|
|     1|  777|    3|
|     1|  888|    2|
|     1|  999|    3|
|     2|  777|    1|
|     2|  888|    3|
|     2|  999|    1|
|     3|  777|    4|
|     3|  888|    2|
+------+-----+-----+

我想将每个 acctId 的 vehId 映射到其计数,并将其存储回数据框中,因此最终结果如下所示:

+------+---------------------------------------------+
|acctId| map                                         |
+------+---------------------------------------------+
|     1| Map(666 -> 1, 777 -> 3, 888 -> 2, 999 -> 3) |
|     2| Map(777 -> 1, 888 -> 3, 999 -> 1)           |
|     3| Map(777 -> 4, 888 -> 2)                     |
+------+---------------------------------------------+

最好的方法是什么?

我尝试将数据框转换为 RDD 并在行上执行映射,但我不确定如何将每个映射聚合回单个 acctId。总的来说,我是 Spark 和数据帧的新手,但我已尽我所能尝试找到类似的问题——如果这是一个非常常见的问题,我深表歉意。

供您参考/使用,这是我生成测试数据的方式:

val testData = Seq(
    (1, 999),
    (1, 999),
    (2, 999),
    (1, 888),
    (2, 888),
    (3, 888),
    (2, 888),
    (2, 888),
    (1, 888),
    (1, 777),
    (1, 666),
    (3, 888),
    (1, 777),
    (3, 777),
    (2, 777),
    (3, 777),
    (3, 777),
    (1, 999),
    (3, 777),
    (1, 777)
).toDF("acctId", "vehId")

val grouped = testData.groupBy("acctId", "vehId").count

【问题讨论】:

    标签: apache-spark apache-spark-sql


    【解决方案1】:

    我认为您必须使用 double groupBy 如下所示

    val testData = Seq(
      (1, 999),
      (1, 999),
      (2, 999),
      (1, 888),
      (2, 888),
      (3, 888),
      (2, 888),
      (2, 888),
      (1, 888),
      (1, 777),
      (1, 666),
      (3, 888),
      (1, 777),
      (3, 777),
      (2, 777),
      (3, 777),
      (3, 777),
      (1, 999),
      (3, 777),
      (1, 777)
    ).toDF("acctId", "vehId")
    
    //udf to convert list to map
    val listToMap = udf((input: Seq[Row]) => input.map(row => (row.getAs[Int](0), row.getAs[Long](1))).toMap)
    
    val resultDF = testData.groupBy("acctId", "vehId")
      .agg(count("acctId").cast("long").as("count"))
      .groupBy("acctId")
      .agg(collect_list(struct("vehId", "count")) as ("map"))
      .withColumn("map", listToMap($"map"))
    

    输出:

    resultDF.show(false)
    +------+----------------------------------------+
    |acctId|map                                     |
    +------+----------------------------------------+
    |1     |[777 -> 3, 666 -> 1, 999 -> 3, 888 -> 2]|
    |3     |[777 -> 4, 888 -> 2]                    |
    |2     |[777 -> 1, 999 -> 1, 888 -> 3]          |
    +------+----------------------------------------+
    

    架构:

    resultDF.printSchema()
    root
     |-- acctId: integer (nullable = false)
     |-- map: map (nullable = true)
     |    |-- key: integer
     |    |-- value: long (valueContainsNull = false)
    

    【讨论】:

    • @Toy_Reid 我相信 map 在这里不是线程安全的,所以我不认为 map 列中的值是一致的。请通过多次验证确保有预期的准确数据
    • @Girish501 最好能在这里解释一下地图如何不是线程安全的。
    • @Girish501 如果您能解释为什么会这样,我将不胜感激!我的测试表明这是准确的,但这将在大型数据集上运行,因此有更多关于如何测试的信息会很棒。
    猜你喜欢
    • 2019-03-16
    • 2019-03-21
    • 1970-01-01
    • 2015-12-27
    • 1970-01-01
    • 1970-01-01
    • 2018-11-10
    • 2021-04-30
    • 1970-01-01
    相关资源
    最近更新 更多