[Question Title]: How to have conditions inside an aggregate function: Scala/Spark?
[Posted]: 2016-10-25 04:19:16
[Question Description]:

I have a DataFrame df = [CUSTOMER_ID, itemType, eventTimeStamp, valueType, value]:

+-------------+----------+-----------+-------+----------------------+
| CUSTOMER_ID | itemType | valueType | value | eventTimeStamp       |
+-------------+----------+-----------+-------+----------------------+
| 1           | null     | dvd       | 12    | 2016-09-19T00:00:00Z |
| 1           | rent     | dvd       | 12    | 2016-09-19T00:00:00Z |
| 1           | buy      | tv        | 12    | 2016-09-20T00:00:00Z |
| 1           | rent     | null      | 12    | 2016-09-20T00:00:00Z |
| 1           | buy      | movie     | 12    | 2016-09-18T00:00:00Z |
| 1           | null     | null      | 12    | 2016-09-18T00:00:00Z |
+-------------+----------+-----------+-------+----------------------+
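
A minimal sketch to reproduce this sample df (untested; assumes spark.implicits._ is in scope from a SparkSession or SQLContext named spark):

    // Build the sample data shown above; the null entries are typed as String by tuple inference
    import spark.implicits._

    val df = Seq(
      ("1", null,   "dvd",   12, "2016-09-19T00:00:00Z"),
      ("1", "rent", "dvd",   12, "2016-09-19T00:00:00Z"),
      ("1", "buy",  "tv",    12, "2016-09-20T00:00:00Z"),
      ("1", "rent", null,    12, "2016-09-20T00:00:00Z"),
      ("1", "buy",  "movie", 12, "2016-09-18T00:00:00Z"),
      ("1", null,   null,    12, "2016-09-18T00:00:00Z")
    ).toDF("CUSTOMER_ID", "itemType", "valueType", "value", "eventTimeStamp")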

I want to get a result like this:

    CUSTOMER_ID    : 1
    totalValue     : 72                                   --- sum of value, grouped by id
    itemTypeMap    : {"rent" : 2, "buy" : 2}              --- counts grouped by id (without null)
    valueTypeMap   : {"dvd" : 2, "tv" : 1, "movie" : 1}   --- counts grouped by id
    itemTypeForDay : {"rent" : 2, "buy" : 2}              --- grouped by id and dayofmonth(col("eventTimeStamp")), at most 1 type per day

My code:

   // Stage 1: collect the itemType / valueType values per customer and sum the values
   val temp = df.groupBy("CUSTOMER_ID").agg(
     collectAsList(df("itemType")).alias("itemCount"),
     collectAsList(df("valueType")).alias("valueTypeCount"),
     sum("value") as "totalValues")

   // Turn the collected arrays into value -> count maps
   val stage1 = temp.withColumn("valueTypeMap", count_by_value(col("valueTypeCount")))
     .withColumn("itemTypeMap", count_by_value(col("itemCount")))
     .drop("itemCount")
     .drop("valueTypeCount")

   // Wraps a single (type, count) pair into a one-entry map
   val toMap = udf { (typ: String, count: Int) => Map(typ -> count) }

   // Counts the occurrences of each distinct value in a collected array
   val count_by_value = udf { (value: scala.collection.mutable.WrappedArray[String]) =>
     if (value == null) null else value.groupBy(identity).mapValues(_.size)
   }

   // collect_list-style UDAF instance (CollectListFunction is defined below)
   val collectAsList = new CollectListFunction(StringType)


   import org.apache.spark.sql.functions.{dayofmonth, countDistinct}

   // Stage 2: per customer and itemType, count the distinct days the type occurs on,
   // then combine the per-type one-entry maps into a single map per customer
   val stage2 = df.groupBy("CUSTOMER_ID", "itemType")
     .agg(countDistinct(dayofmonth(col("eventTimeStamp"))) as "daysPeritemType")
     .withColumn("itemTypeForDay", toMap(col("itemType"), col("daysPeritemType")))
     .groupBy("CUSTOMER_ID").agg(CombineMaps(col("itemTypeForDay")) as "resultMap")

   val result = stage1.join(stage2, stage1("CUSTOMER_ID") === stage2("CUSTOMER_ID"))
     .drop(stage2("CUSTOMER_ID"))

This gives me results that still contain null. How can I avoid the nulls while doing the aggregation? I don't want to drop the null rows/columns entirely; I just need to skip them when aggregating over those particular rows.
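
For reference, a condition can be pushed inside a built-in aggregate with when (a minimal, untested sketch using the columns above); the open question is how to get the same null-skipping effect in the custom collect/map pipeline:

    // Conditional aggregation with built-ins: count() ignores the nulls produced by when()
    import org.apache.spark.sql.functions._

    val conditional = df.groupBy("CUSTOMER_ID").agg(
      sum("value").alias("totalValue"),
      count(when(col("itemType").isNotNull, col("itemType"))).alias("nonNullItemTypes"))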

Utility classes:

    // JSON-serialisable container for the per-customer features
    case class Data(totalValue: Long, typeCount: Map[String, Int], typeForDay: Map[String, Int], itemCount: Map[String, Int]) extends Serializable

    // Converts the aggregated DataFrame into (customerId, jsonFeatureString) pairs
    def convertToRDD(result: DataFrame): RDD[(String, String)] = {
      val tempFile = result.map { r =>
        val customerId = r.getAs[String]("CUSTOMER_ID")
        val totalValue = r.getAs[Long]("totalValue")

        val typeCount  = r.getAs[Map[String, Int]]("typeCount")
        val itemCount  = r.getAs[Map[String, Int]]("itemCount")
        val typeForDay = r.getAs[Map[String, Int]]("typeForDay")

        val features = Data(totalValue, typeCount, typeForDay, itemCount)
        val jsonString = JacksonUtil.toJson(features)

        (customerId, jsonString)
      }
      tempFile
    }

class CollectListFunction[T] (val colType: DataType) extends UserDefinedAggregateFunction {

  def inputSchema: StructType =
    new StructType().add("inputCol", colType)

  def bufferSchema: StructType =
    new StructType().add("outputCol", ArrayType(colType))

  def dataType: DataType = ArrayType(colType)

  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, new scala.collection.mutable.ArrayBuffer[T])
  }

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val list = buffer.getSeq[T](0)
    // null inputs are skipped here, so they never reach the collected list
    if (!input.isNullAt(0)) {
      val sales = input.getAs[T](0)
      buffer.update(0, list :+ sales)
    }
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getSeq[T](0) ++ buffer2.getSeq[T](0))
  }

  def evaluate(buffer: Row): Any = {
    buffer.getSeq[T](0)
  }
}



object CombineMaps extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = new StructType().add("map", dataType)
  override def bufferSchema: StructType = inputSchema
  override def dataType: DataType = MapType(StringType, IntegerType)
  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer.update(0 , Map[String, Int]())

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val before = buffer.getAs[Map[String, Int]](0)
    val toAdd = input.getAs[Map[String, Int]](0)
    // ++ keeps the right-hand value when both maps share a key; it does not add the counts
    val result = before ++ toAdd
    buffer.update(0, result)
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = update(buffer1, buffer2)

  override def evaluate(buffer: Row): Any = buffer.getAs[Map[String, Int]](0)
}
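
A side note on CombineMaps: ++ keeps the right-hand value when both maps contain the same key, which works here because each per-itemType map arrives with a distinct key. If the counts ever had to be added instead, a key-wise merge inside update would look roughly like this (a sketch reusing the before/toAdd names above, not part of the original code):

    // Hypothetical variant: sum the counts for keys present in both maps
    val result = (before.keySet ++ toAdd.keySet).map { k =>
      k -> (before.getOrElse(k, 0) + toAdd.getOrElse(k, 0))
    }.toMap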

Can anyone tell me what I am doing wrong here?

[Comments]:

    Tags: scala apache-spark dataframe group-by aggregate


    [Solution 1]:

    As far as I understand your requirement, you can filter out the null values inside the udf itself:

    udf { (value: scala.collection.mutable.WrappedArray[String]) =>
      if (value == null) null
      else value.filter(_ != null).groupBy(identity).mapValues(_.size)
    }
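
    For completeness, a sketch of the filtered udf dropped back into the question's stage1 (untested; names reused from the question):

    // Nulls are removed inside the udf, so they never become keys in the count maps
    val count_by_value = udf { (value: scala.collection.mutable.WrappedArray[String]) =>
      if (value == null) null
      else value.filter(_ != null).groupBy(identity).mapValues(_.size)
    }

    val stage1 = temp.withColumn("valueTypeMap", count_by_value(col("valueTypeCount")))
      .withColumn("itemTypeMap", count_by_value(col("itemCount")))
      .drop("itemCount")
      .drop("valueTypeCount")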
    

    [Discussion]:
