【问题标题】:Apply function (scanLeft) to partition to create new column in dataframe将函数(scanLeft)应用于分区以在数据框中创建新列
【发布时间】:2019-09-24 11:54:11
【问题描述】:

我想对数据框的一列执行 scanLeft 类型的操作。 Scanleft 不可并行化,但在我的情况下,我只想将此函数应用于已经在同一分区中的元素。因此可以在每个分区中并行执行操作。 (无数据洗牌)

考虑以下示例:

| partitionKey  | orderColumn   | value     | scanLeft(0)(_+_)  |
|-------------- |-------------  |-------    |------------------ |
| 1             | 1             | 1         | 1                 |
| 1             | 2             | 2         | 3                 |
| 2             | 1             | 3         | 3                 |
| 2             | 2             | 4         | 7                 |
| 1             | 3             | 5         | 8                 |
| 2             | 3             | 6         | 13                |

我想scanLeft同一分区内的值,并创建一个新列来存储结果。

我现在的代码看起来像这样:

    inDataframe
      .repartition(col("partitionKey"))
      .foreachPartition{
      partition =>
        partition.map(row => row(1).asInstanceOf[Double])
      .scanLeft(0.0)(_+_)
      .foreach(println(_))
    })

这会根据需要聚合值并打印出结果,但是我想将这些值添加为数据框的新列

知道怎么做吗?

----编辑---- 真正的用例是计算时间加权收益率(https://www.investopedia.com/terms/t/time-weightedror.asp) 预期的输入看起来像这样:

| product   | valuation date    | daily return  |
|---------  |----------------   |-------------- |
| 1         | 2019-01-01        | 0.1           |
| 1         | 2019-01-02        | 0.2           |
| 1         | 2019-01-03        | 0.3           |
| 2         | 2019-01-01        | 0.4           |
| 2         | 2019-01-02        | 0.5           |
| 2         | 2019-01-03        | 0.6           |

我想计算当前日期之前所有日期的每个产品的累计回报。 Dataframe 按产品分区,分区按估价日期排序。我已经编写了要传递给 scanLeft 的聚合函数:

  def chain_ret (x: Double, y: Double): Double = {
    (1 + x) * (1 + y) - 1
  }

预期返回数据:

| product   | valuation date    | daily return  | cumulated return  |
|---------  |----------------   |-------------- |------------------ |
| 1         | 2019-01-01        | 0.1           | 0.1               |
| 1         | 2019-01-02        | 0.2           | 0.32              |
| 1         | 2019-01-03        | 0.3           | 0.716             |
| 2         | 2019-01-01        | 0.4           | 0.4               |
| 2         | 2019-01-02        | 0.5           | 1.1               |
| 2         | 2019-01-03        | 0.6           | 2.36              |

我已经通过过滤给定日期范围的数据框并应用 UDAF 来解决这个问题。 (看下面)它很长,我认为使用 scanLeft 会快得多!

    while(endPeriod.isBefore(end)) {
      val filtered = inDataframe
        .where("VALUATION_DATE >= '" + start + "' AND VALUATION_DATE <= '" + endPeriod + "'")
      val aggregated = aggregate_returns(filtered)
        .withColumn("VALUATION_DATE", lit(Timestamp.from(endPeriod)).cast(TimestampType))
      df_ret = df_ret.union(aggregated)
      endPeriod = endPeriod.plus(1, ChronoUnit.DAYS)
    }

 def aggregate_returns(inDataframe: DataFrame): DataFrame = {
    val groupedByKey = inDataframe
      .groupBy("product")
    groupedByKey
      .agg(
        returnChain(col("RETURN_LOCAL")).as("RETURN_LOCAL_CUMUL"),
        returnChain(col("RETURN_FX")).as("RETURN_FX_CUMUL"),
        returnChain(col("RETURN_CROSS")).as("RETURN_CROSS_CUMUL"),
        returnChain(col("RETURN")).as("RETURN_CUMUL")
      )

class ReturnChain extends UserDefinedAggregateFunction{

  // Defind the schema of the input data
  override def inputSchema: StructType =
    StructType(StructField("return", DoubleType) :: Nil)

  // Define how the aggregates types will be
  override def bufferSchema: StructType = StructType(
    StructField("product", DoubleType) :: Nil
  )

  // define the return type
  override def dataType: DataType = DoubleType

  // Does the function return the same value for the same input?
  override def deterministic: Boolean = true

  // Initial values
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.toDouble
  }

  // Updated based on Input
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = (1.toDouble + buffer.getAs[Double](0)) * (1.toDouble + input.getAs[Double](0))
  }

  // Merge two schemas
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Double](0) + buffer2.getAs[Double](0)
  }

  // Output
  override def evaluate(buffer: Row): Any = {
    buffer.getDouble(0) - 1.toDouble
  }
}

【问题讨论】:

  • 您能否添加一些输入和预期数据,可能还有其他方法可以做到这一点。
  • @ShankarKoirala 我添加了真正的用例
  • 您应该注意:如果 tha 数据帧被 partitionKey 分区,这并不意味着 1 个分区仅包含 1 个 partitionKey,而是 1 个 partitionKey 仅存在于 1 个分区中。因此,您还需要在 mapPartitions 内使用 groupBy.... 一般来说,我会尝试以某种方式使用窗口函数来解决它

标签: scala dataframe apache-spark partition foldleft


【解决方案1】:

foreachPartition 不返回任何内容,您需要使用 .mapPartition() 代替

foreachPartition 和 mapPartition 的区别与 map 和 foreach 的区别相同。在这里寻找好的解释Foreach vs Map in Scala

【讨论】:

  • 好的,我使用 forEach,因为我正在打印 scanLeft 的结果。你知道如何用 mapPartitions 创建一个新列吗?
  • 第一个问题。你真的需要使用重新分区吗?
  • 感谢您的帮助!我添加了我的用例的详细说明。实际上数据来自Cassandra db,并按产品分区存储。
猜你喜欢
  • 1970-01-01
  • 2019-08-21
  • 1970-01-01
  • 2016-06-15
  • 1970-01-01
  • 2013-02-13
  • 1970-01-01
  • 2013-04-20
  • 2019-05-01
相关资源
最近更新 更多