【问题标题】:Spark UDAF generics type mismatchSpark UDAF 泛型类型不匹配
【发布时间】:2018-03-10 12:28:39
【问题描述】:

我正在尝试在 Spark(2.0.1,Scala 2.11)上创建一个 UDAF,如下所示。这本质上是聚合元组并输出Map

import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.{Row, Column}

class mySumToMap[K, V](keyType: DataType, valueType: DataType) extends UserDefinedAggregateFunction {
  override def inputSchema = new StructType()
    .add("a_key", keyType)
    .add("a_value", valueType)

  override def bufferSchema = new StructType()
    .add("buffer_map", MapType(keyType, valueType))

  override def dataType = MapType(keyType, valueType)

  override def deterministic = true 

  override def initialize(buffer: MutableAggregationBuffer) = {
    buffer(0) = Map[K, V]()
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {

    // input :: 0 = a_key (k), 1 = a_value
    if ( !(input.isNullAt(0)) ) {

      val a_map = buffer(0).asInstanceOf[Map[K, V]]
      val k = input.getAs[K](0)  // get the value of position 0 of the input as string (a_key)

      // I've split these on purpose to show that return values are all of type V
      val new_v1: V = a_map.getOrElse(k, 0.asInstanceOf[V])
      val new_v2: V = input.getAs[V](1)
      val new_v: V = new_v1 + new_v2

      buffer(0) = if (new_v != 0) a_map + (k -> new_v) else a_map - k
    }
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    val map1: Map[K, V] = buffer1(0).asInstanceOf[Map[K, V]]
    val map2: Map[K, V] = buffer2(0).asInstanceOf[Map[K, V]]

    buffer1(0) = map1 ++ map2.map{ case (k,v) => k -> (v + map1.getOrElse(k, 0.asInstanceOf[V])) }
  }

  override def evaluate(buffer: Row) = buffer(0).asInstanceOf[Map[K, V]]

}

但是当我编译这个时,我看到以下错误:

<console>:74: error: type mismatch;
 found   : V
 required: String
             val new_v: V = new_v1 + new_v2
                                     ^
<console>:84: error: type mismatch;
 found   : V
 required: String
           buffer1(0) = map1 ++ map2.map{ case (k,v) => k -> (v + map1.getOrElse(k, 0.asInstanceOf[V])) }

我做错了什么?

编辑: 对于将其标记为 Spark UDAF - using generics as input type? 重复的人 - 这不是该问题的重复,因为该问题不处理 Map 数据类型.对于使用 Map 数据类型所面临的问题,上面的代码非常具体和完整。

【问题讨论】:

  • 你为什么假设类型V 有一个+ 操作符(方法)?您没有将它绑定到任何特定的东西,因此它可以是任何类,包括未定义此运算符的类。您希望将V 绑定为任何数字 类型吗?
  • @TzachZohar 看起来这个错误与我做加法的方式有关?在您的CombineMaps 示例中(这真的很棒!顺便说一句),我试图摆脱merge (因为我只需要在我的用例中添加)。声明 val result = map1 ++ map2.map{case(k,v) =&gt; k -&gt; map1.get(k).map(v + _).getOrElse(v) } 正在抛出与上述完全相同的错误!
  • 这就是为什么需要merge 参数的原因:对于泛型类型V,UDAF 怎么知道如何将两个值合并为一个?对于数值,+ 是一个不错的选择,但对于没有定义 + 运算符的非数值 - 您需要调用者提供匹配函数。无论如何 - 你从@user8371915 那里得到了一个很好的答案

标签: scala apache-spark user-defined-aggregate


【解决方案1】:

将类型限制为具有Numeric[_] 的类型

class mySumToMap[K, V: Numeric](keyType: DataType, valueType: DataType) 
  extends UserDefinedAggregateFunction {
    ...

使用Implicitly 在运行时获取它:

val n = implicitly[Numeric[V]]

并使用其plus 方法代替+zero 代替0

buffer1(0) = map1 ++ map2.map{ 
  case (k,v) => k -> n.plus(v,  map1.getOrElse(k, n.zero))
}

要支持更广泛的类型,您可以使用cats Monoid

import cats._
import cats.implicits._

并调整代码:

class mySumToMap[K, V: Monoid](keyType: DataType, valueType: DataType) 
  extends UserDefinedAggregateFunction {
    ...

及以后:

override def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
  val map1: Map[K, V] = buffer1.getMap[K, V](0)
  val map2: Map[K, V] = buffer2.getMap[K, V](0)

  val m = implicitly[Monoid[Map[K, V]]]

  buffer1(0) = m.combine(map1, map2)
}

【讨论】:

    猜你喜欢
    • 2023-03-21
    • 1970-01-01
    • 1970-01-01
    • 2020-10-18
    • 2010-09-20
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多