[Question Title]: Spark RDD - Replacing the missing columns with the average of other columns
[Posted]: 2018-12-10 18:10:19
[Question Description]:

I have an RDD that looks like this:

RDD( (001, 1, 0, 3, 4), (001, 3, 4, 1, 7), (001, , 0, 6, 4), (003, 1, 4, 5, 7), (003, 5, 4, , 2), (003, 4, , 9, 2), (003, 2, 3, 0, 1) )

The first column is the contract ID (001 and 003). I need to group the records that share a contract ID, compute the average of every column other than the contract ID, and then replace each missing value with the corresponding column average for that contract ID. (A missing value counts as zero toward the average, so the divisor is the group's total row count, as in the expected output below.)

So the final output would be:

RDD( (001, 1, 0, 3, 4), (001, 3, 4, 1, 7), (001, (1+3)/3 , 0, 6, 4), (003, 1, 4, 5, 7), (003, 5, 4, (5+9+0)/4 , 2), (003, 4, (4+4+3)/4 , 9, 2), (003, 2, 3, 0, 1) )

I did a groupByKey using the contract ID as the key, and then I got stuck. I would really appreciate any suggestions.
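
The grouping step I have so far looks roughly like this (a sketch only; parsing the value columns as Option[Double], with None marking a missing entry, is just one possible representation):

import org.apache.spark.rdd.RDD

// Each record: contract ID plus the remaining columns as Option[Double],
// with None marking a missing value (only two sample rows shown).
val records: RDD[(String, Array[Option[Double]])] = sc.parallelize(Seq(
  ("001", Array[Option[Double]](Some(1.0), Some(0.0), Some(3.0), Some(4.0))),
  ("001", Array[Option[Double]](None, Some(0.0), Some(6.0), Some(4.0)))
))

// Group all rows sharing a contract ID; the averaging and replacement
// still needs to happen per grouped Iterable, which is where I'm stuck.
val grouped: RDD[(String, Iterable[Array[Option[Double]]])] = records.groupByKey()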

[Question Discussion]:

    Tags: scala apache-spark rdd


    [Solution 1]:
    // Create the exact input data provided as a Spark DataFrame/DataSet
    val df = {
      import org.apache.spark.sql._
      import org.apache.spark.sql.types._
      import scala.collection.JavaConverters._
    
      val simpleSchema = StructType(
        StructField("a", StringType) ::
        StructField("b", IntegerType) ::
        StructField("c", IntegerType) ::
        StructField("d", IntegerType) ::
        StructField("e", IntegerType) :: Nil)
    
      val data = List(
        Row("001", 1, 0, 3, 4),
        Row("001", 3, 4, 1, 7),
        Row("001", null, 0, 6, 4),
        Row("003", 1, 4, 5, 7),
        Row("003", 5, 4, null, 2),
        Row("003", 4, null, 9, 2),
        Row("003", 2, 3, 0, 1)
      )
    
      spark.createDataFrame(data.asJava, simpleSchema)
    }
    
    import org.apache.spark.sql.functions.{coalesce, col}

    // na.fill(0) replaces nulls with zero before averaging, so each average
    // divides by the group's total row count, matching the expected output.
    val avgs = df.na.fill(0).groupBy(col("a")).avg("b", "c", "d", "e").as("avgs")
    
    val output = df.as("df").join(avgs, col("df.a") === col("avgs.a")).select(col("df.a"),
      coalesce(col("df.b"), col("avg(b)")),
      coalesce(col("df.c"), col("avg(c)")),
      coalesce(col("df.d"), col("avg(d)")),
      coalesce(col("df.e"), col("avg(e)"))
      )
    
    scala> df.show()
    +---+----+----+----+---+
    |  a|   b|   c|   d|  e|
    +---+----+----+----+---+
    |001|   1|   0|   3|  4|
    |001|   3|   4|   1|  7|
    |001|null|   0|   6|  4|
    |003|   1|   4|   5|  7|
    |003|   5|   4|null|  2|
    |003|   4|null|   9|  2|
    |003|   2|   3|   0|  1|
    +---+----+----+----+---+
    
    
    scala> avgs.show()
    +---+------------------+------------------+------------------+------+
    |  a|            avg(b)|            avg(c)|            avg(d)|avg(e)|
    +---+------------------+------------------+------------------+------+
    |003|               3.0|              2.75|               3.5|   3.0|
    |001|1.3333333333333333|1.3333333333333333|3.3333333333333335|   5.0|
    +---+------------------+------------------+------------------+------+
    
    
    scala> output.show()
    +---+----------------------+----------------------+----------------------+----------------------+
    |  a|coalesce(df.b, avg(b))|coalesce(df.c, avg(c))|coalesce(df.d, avg(d))|coalesce(df.e, avg(e))|
    +---+----------------------+----------------------+----------------------+----------------------+
    |001|                   1.0|                   0.0|                   3.0|                   4.0|
    |001|                   3.0|                   4.0|                   1.0|                   7.0|
    |001|    1.3333333333333333|                   0.0|                   6.0|                   4.0|
    |003|                   1.0|                   4.0|                   5.0|                   7.0|
    |003|                   5.0|                   4.0|                   3.5|                   2.0|
    |003|                   4.0|                  2.75|                   9.0|                   2.0|
    |003|                   2.0|                   3.0|                   0.0|                   1.0|
    +---+----------------------+----------------------+----------------------+----------------------+
    

    [Discussion]:

    • Thanks for the answer. In the example above I only have 5 columns; what if I have to average more than 50 columns?
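    • One way to avoid spelling out 50+ columns by hand (a sketch building on the answer above; valueCols, avgsWide and outputWide are illustrative names) is to generate the avg and coalesce expressions from df.columns:

      import org.apache.spark.sql.functions.{avg, coalesce, col}

      // Every column except the contract ID gets an average and a coalesce.
      val valueCols = df.columns.filterNot(_ == "a")
      val avgExprs  = valueCols.map(c => avg(col(c)).as(s"avg_$c"))

      val avgsWide = df.na.fill(0).groupBy(col("a")).agg(avgExprs.head, avgExprs.tail: _*)

      val outputWide = df.as("df")
        .join(avgsWide.as("avgs"), col("df.a") === col("avgs.a"))
        .select(col("df.a") +: valueCols.map(c => coalesce(col(s"df.$c"), col(s"avg_$c")).as(c)): _*)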
    [Solution 2]:

    This can also be achieved with SQL window functions, without any join. Note that sum(x) over(...) / count(*) over(...) is used instead of avg(x) over(...): avg skips nulls, so it would divide by the non-null count rather than the group's total row count. Check this out:

    import spark.implicits._ // needed for .toDF on a local Seq

    val df = Seq(
      ("001", Some(1), Some(0), Some(3), Some(4)),
      ("001", Some(3), Some(4), Some(1), Some(7)),
      ("001", None, Some(0), Some(6), Some(4)),
      ("003", Some(1), Some(4), Some(5), Some(7)),
      ("003", Some(5), Some(4), None, Some(2)),
      ("003", Some(4), None, Some(9), Some(2)),
      ("003", Some(2), Some(3), Some(0), Some(1))
    ).toDF("a","b","c","d","e")
    df.show(false)
    df.createOrReplaceTempView("avg_temp")
    
    spark.sql("""  select a, coalesce(b,sum(b) over(partition by a)/count(*) over(partition by a)) b1, coalesce( c, sum(c) over(partition by a)/count(*) over(partition by a)) c1,
                coalesce( d, sum(d) over(partition by a)/count(*) over(partition by a)) d1, coalesce( e, sum(e) over(partition by a)/count(*) over(partition by a)) e1 from avg_temp
    """).show(false)
    

    Result:

    +---+----+----+----+---+
    |a  |b   |c   |d   |e  |
    +---+----+----+----+---+
    |001|1   |0   |3   |4  |
    |001|3   |4   |1   |7  |
    |001|null|0   |6   |4  |
    |003|1   |4   |5   |7  |
    |003|5   |4   |null|2  |
    |003|4   |null|9   |2  |
    |003|2   |3   |0   |1  |
    +---+----+----+----+---+
    +---+------------------+----+---+---+
    |a  |b1                |c1  |d1 |e1 |
    +---+------------------+----+---+---+
    |003|1.0               |4.0 |5.0|7.0|
    |003|5.0               |4.0 |3.5|2.0|
    |003|4.0               |2.75|9.0|2.0|
    |003|2.0               |3.0 |0.0|1.0|
    |001|1.0               |0.0 |3.0|4.0|
    |001|3.0               |4.0 |1.0|7.0|
    |001|1.3333333333333333|0.0 |6.0|4.0|
    +---+------------------+----+---+---+
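
    For reference, the same window logic can be written with the DataFrame API instead of raw SQL (a sketch; w and filled are illustrative names):

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{coalesce, col, count, lit, sum}

    val w = Window.partitionBy("a")

    // Replace each null with sum(col)/count(*) over the contract's window;
    // sum skips nulls, so the divisor is the group's total row count.
    val filled = Seq("b", "c", "d", "e").foldLeft(df) { (acc, c) =>
      acc.withColumn(c, coalesce(col(c), sum(col(c)).over(w) / count(lit(1)).over(w)))
    }
    filled.show(false)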
    

    [Discussion]:
