【发布时间】:2016-10-25 04:19:16
【问题描述】:
我有一个 df = [CUSTOMER_ID ,itemType, eventTimeStamp, valueType, value ]
+-------------+-----------+-------------+--------+--------------------+
| CUSTOMER_ID | itemType | valueType | value | eventTimeStamp |
+-------------+-----------+--------+---------------------------------+
| 1 | null | dvd | 12 |2016-09-19T00:00:00Z
| 1 | rent | dvd | 12 |2016-09-19T00:00:00Z
| 1 | buy | tv | 12 |2016-09-20T00:00:00Z
| 1 | rent | null | 12 |2016-09-20T00:00:00Z
| 1 | buy | movie | 12 |2016-09-18T00:00:00Z
| 1 | null | null | 12 |2016-09-18T00:00:00Z
+-------------+-----------+-------------+--------+---------------------+
我想得到如下结果:
CUSTOMER_ID : 1
totalValue : 72 --- group by based on id
itemTypeMap : {"rent" : 2, "buy" : 2} --- group by based on id (without null)
valueTypeMap : {"dvd" : 2, "tv" : 1, "movie" : 1 } --- group by based on id
itemTypeForDay : {"rent: 2, "buy" : 2 } --- group By based on id and dayofmonth(col("eventTimeStamp")) atmost 1 type per day
我的代码:
val temp = df.groupBy("CUSTOMER_ID").agg(
collectAsList(df("itemType")).alias("itemCount"),
collectAsList(df("valueType")).alias("valueTypeCount"),
sum("value") as "totalValues")
val stage1 = temp.withColumn("valueTypeMap", count_by_value(col("valueTypeCount")))
.withColumn("itemTypeMap", count_by_value(col("itemCount")))
.drop("itemCount")
.drop("valueTypeCount")
val toMap = udf { (typ: String, count: Int) => Map(typ -> count) }
val count_by_value = udf {( value :scala.collection.mutable.WrappedArray[String]) => if (value == null) null else value.groupBy(identity).mapValues(_.size)}
val collectAsList = new CollectListFunction(StringType)
import org.apache.spark.sql.functions.{dayofmonth, countDistinct}
val stage2 = df.groupBy("CUSTOMER_ID", "itemType")
.agg(countDistinct(dayofmonth(col("eventTimeStamp"))) as "daysPeritemType")
.withColumn("itemTypeForDay", toMap(col("itemType"), col("daysPeritemType")))
.groupBy("CUSTOMER_ID").agg(CombineMaps(col("itemTypeForDay")) as "resultMap")
val result = stage1.join(stage2, stage1("CUSTOMER_ID") === stage2("CUSTOMER_ID"))
.drop(stage2("CUSTOMER_ID"))
这给了我包含 null 的结果。在进行聚合时如何避免 null。我不想完全删除空行/列。只需要在对特定行进行聚合时避免它们。
实用类:
case class Data(totalValue: Long, typeCount: Map[String,Int], typeForDay: Map[String,Int] ,itemCount : Map[String,Int]) extends Serializable
def convertToRDD(result : DataFrame): RDD[(String, String)] = {
val tempFile = result.map( {
r => {
val customerId = r.getAs[String]( "CUSTOMER_ID" )
val totalValue = r.getAs[Long]( "totalValue" )
val typeCount = r.getAs[Map[String, Int]]( "typeCount" )
val itemCount = r.getAs[Map[String, Int]]( "itemCount" )
val typeForDay = r.getAs[Map[String, Int]]( "typeForDay" )
val features = Data( totalValue, typeCount, typeForDay, itemCount)
val jsonString = JacksonUtil.toJson( features )
(customerId, jsonString)
}
} )
return tempFile
}
class CollectListFunction[T] (val colType: DataType) extends UserDefinedAggregateFunction {
def inputSchema: StructType =
new StructType().add("inputCol", colType)
def bufferSchema: StructType =
new StructType().add("outputCol", ArrayType(colType))
def dataType: DataType = ArrayType(colType)
def deterministic: Boolean = true
def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer.update(0, new scala.collection.mutable.ArrayBuffer[T])
}
def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
val list = buffer.getSeq[T](0)
if (!input.isNullAt(0)) {
val sales = input.getAs[T](0)
buffer.update(0, list:+sales)
}
}
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1.update(0, buffer1.getSeq[T](0) ++ buffer2.getSeq[T](0))
}
def evaluate(buffer: Row): Any = {
buffer.getSeq[T](0)
}
}
object CombineMaps extends UserDefinedAggregateFunction {
override def inputSchema: StructType = new StructType().add("map", dataType)
override def bufferSchema: StructType = inputSchema
override def dataType: DataType = MapType(StringType, IntegerType)
override def deterministic: Boolean = true
override def initialize(buffer: MutableAggregationBuffer): Unit = buffer.update(0 , Map[String, Int]())
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
val before = buffer.getAs[Map[String, Int]](0)
val toAdd = input.getAs[Map[String, Int]](0)
val result = before ++ toAdd
buffer.update(0, result)
}
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = update(buffer1, buffer2)
override def evaluate(buffer: Row): Any = buffer.getAs[Map[String, Int]](0)
}
谁能告诉我我在这里做错了什么?
【问题讨论】:
标签: scala apache-spark dataframe group-by aggregate