【问题标题】:Calculating column value in current row of Spark Dataframe based on the calculated value of a different column in previous row using Scala使用Scala根据上一行中不同列的计算值计算Spark Dataframe当前行中的列值
【发布时间】:2022-11-11 06:54:27
【问题描述】:

假设我有一个如下所示的数据框

Id A B C D
1 100 10 20 5
2 0 5 10 5
3 0 7 2 3
4 0 1 3 7

以上需要转换为类似下面的内容

Id A B C D E
1 100 10 20 5 75
2 75 5 10 5 60
3 60 7 2 3 50
4 50 1 3 7 40

这件事通过下面提供的细节起作用

  1. 数据框现在有一个新列 E,第 1 行计算为 col(A) - (max(col(B), col(C)) + col(D)) => 100-(max(10,20) + 5) = 75
  2. Id 2 的行中,第 1 行中 col E 的值被向前推为 Col A 的值
  3. 因此,对于第 2 行,E 列被确定为 75-(max(5,10) + 5) = 60
  4. Id 3 的行类似,A 的值变为 60,col E 的新值基于此确定

    问题是,col A 的值取决于除第一行之外的前一行的值

    是否有可能使用窗口和滞后来解决这个问题

【问题讨论】:

    标签: scala apache-spark apache-spark-sql


    【解决方案1】:

    您可以在按Id 列排序的窗口上使用collect_list 函数,并获取包含Amax(B, C) + D 值的结构的累积数组(作为字段T)。然后,应用aggregate 计算列E

    请注意,在这种特殊情况下,您不能使用 lag 窗口函数,因为您希望递归地获取计算值。

    import org.apache.spark.sql.expressions.Window
    
    val df2 = df.withColumn(
      "tmp",
      collect_list(
        struct(col("A"), (greatest(col("B"), col("C")) + col("D")).as("T"))
      ).over(Window.orderBy("Id"))
    ).withColumn(
      "E",
      expr("aggregate(transform(tmp, (x, i) -> IF(i=0, x.A - x.T, -x.T)), 0, (acc, x) -> acc + x)")
    ).withColumn(
      "A",
      col("E") + greatest(col("B"), col("C")) + col("D")
    ).drop("tmp")
    
    df2.show(false)
    
    //+---+---+---+---+---+---+
    //|Id |A  |B  |C  |D  |E  |
    //+---+---+---+---+---+---+
    //|1  |100|10 |20 |5  |75 |
    //|2  |75 |5  |10 |5  |60 |
    //|3  |60 |7  |2  |3  |50 |
    //|4  |50 |1  |3  |7  |40 |
    //+---+---+---+---+---+---+
    

    您可以显示中间列tmp 以了解计算背后的逻辑。

    【讨论】:

    • 嗨,blackbishop,非常感谢您的回复。问题是“max(B, C) + D”是实际计算的一个非常简单的版本。实际上,计算涉及从前一行到当前行的多列。并且自定义聚合将变得过于复杂而无法处理。这是我的错,因为我认为它会以某种方式使用滞后来获取先前的值,然后在相同的情况下使用正常的数据帧计算。但这似乎比我想象的要复杂得多
    • 嗨@Soumya!使用简单的 Window 函数无法做到这一点,因为您的计算需要递归。也许您可以提出一个新问题,详细解释您要解决的问题。我们尝试根据您提供的元素来回答问题,遗憾的是我们无法猜测您的实际任务是否更加复杂。
    【解决方案2】:

    作为blackbishop said,您不能使用滞后函数来检索列的变化值。当您使用 scala API 时,您可以开发自己的 User-Defined Aggregate Function

    您创建以下案例类,代表您当前正在读取的行和聚合器的缓冲区:

    case class InputRow(A: Integer, B: Integer, C: Integer, D: Integer)
    
    case class Buffer(var E: Integer, var A: Integer)
    

    然后使用它们来定义您的 RecursiveAggregator 自定义聚合器:

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    import org.apache.spark.sql.expressions.Aggregator
    import org.apache.spark.sql.Encoder
    
    object RecursiveAggregator extends Aggregator[InputRow, Buffer, Buffer] {
      override def zero: Buffer = Buffer(null, null)
    
      override def reduce(buffer: Buffer, currentRow: InputRow): Buffer = {
        buffer.A = if (buffer.E == null) currentRow.A else buffer.E
        buffer.E = buffer.A - (math.max(currentRow.B, currentRow.C) + currentRow.D)
        buffer
      }
    
      override def merge(b1: Buffer, b2: Buffer): Buffer = {
        throw new NotImplementedError("should be used only over ordered window")
      }
    
      override def finish(reduction: Buffer): Buffer = reduction
    
      override def bufferEncoder: Encoder[Buffer] = ExpressionEncoder[Buffer]
    
      override def outputEncoder: Encoder[Buffer] = ExpressionEncoder[Buffer]
    }
    

    最后,您将 RecursiveAggregator 转换为您在 input 数据帧上应用的用户定义聚合函数:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{col, udaf}
    
    val recursiveAggregator = udaf(RecursiveAggregator)
    
    val window = Window.orderBy("Id")
    
    val result = input
      .withColumn("computed", recursiveAggregator(col("A"), col("B"), col("C"), col("D")).over(window))
      .select("Id", "computed.A", "B", "C", "D", "computed.E")
    

    如果您将问题的数据框作为input 数据框,您将获得以下result 数据框:

    +---+---+---+---+---+---+
    |Id |A  |B  |C  |D  |E  |
    +---+---+---+---+---+---+
    |1  |100|10 |20 |5  |75 |
    |2  |75 |5  |10 |5  |60 |
    |3  |60 |7  |2  |3  |50 |
    |4  |50 |1  |3  |7  |40 |
    +---+---+---+---+---+---+
    

    【讨论】:

    • 非常感谢您的帮助。在尝试复制时是否可以在 Spark2+ 版本中复制相同的内容。我认为“udaf”仅在 Spark3+ 中可用,但不幸的是我仍然坚持使用旧版本的 Spark :(
    • 没错,udaf 函数在 Spark 2 中不存在。您可以查看 this answer 以在 Spark 2 中使用用户定义的聚合函数。
    • 任何人都可以分享有关如何包装此 UDAF 以与 PySpark 一起使用的任何见解吗?尝试用它构建一个罐子并用 PySpark 推动它/注册它时撞到砖墙:(
    【解决方案3】:

    我在 spark 2.3.0 中尝试过,但聚合函数有错误: 用户类抛出异常:Java.io.IOError: org.apache.spark.sql.catalyst.ParserException: externous input '>' excepting {'(','SELECT','FROM'....

    我在本地尝试使用 spark 2.4.0 并且它可以工作,但在我们的 hdp 中我们有 2.3.0

    你能帮我让它在 spark 2.3.0 上工作吗

    提前致谢

    【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2020-08-23
    • 2017-04-14
    相关资源
    最近更新 更多