【发布时间】:2020-07-23 05:13:10
【问题描述】:
我正在研究一个复杂的逻辑,我需要将数量从一个数据集重新分配到另一个数据集。
这个问题是this question的延续
在下面的示例中,我将介绍几个新维度。在汇总和分配所有数量后,我期望总数量相同,但我有一些差异。
请看下面的例子
package playground
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, round, sum}
object sample3 {
val spark = SparkSession
.builder()
.appName("Sample app")
.master("local")
.getOrCreate()
val sc = spark.sparkContext
final case class Owner(a: Long,
b: String,
c: Long,
d: Short,
e: String,
f: String,
o_qtty: Double)
// notice column d is not present in Invoice
final case class Invoice(c: Long,
a: Long,
b: String,
e: String,
f: String,
i_qtty: Double)
def main(args: Array[String]): Unit = {
Logger.getLogger("org").setLevel(Level.OFF)
import spark.implicits._
val ownerData = Seq(
Owner(11, "A", 666, 2017, "x", "y", 50),
Owner(11, "A", 222, 2018, "x", "y", 20),
Owner(33, "C", 444, 2018, "x", "y", 20),
Owner(33, "C", 555, 2018, "x", "y", 120),
Owner(22, "B", 555, 2018, "x", "y", 20),
Owner(99, "D", 888, 2018, "x", "y", 100),
Owner(11, "A", 888, 2018, "x", "y", 100),
Owner(11, "A", 666, 2018, "x", "y", 80),
Owner(33, "C", 666, 2018, "x", "y", 80),
Owner(11, "A", 444, 2018, "x", "y", 50),
)
val invoiceData = Seq(
Invoice(444, 33, "C", "x", "y", 10),
Invoice(999, 22, "B", "x", "y", 200),
Invoice(666, 11, "A", "x", "y", 15),
Invoice(555, 22, "B", "x", "y", 200),
Invoice(888, 11, "A", "x", "y", 12),
)
val owners = spark
.createDataset(ownerData)
.as[Owner]
.cache()
val invoices = spark
.createDataset(invoiceData)
.as[Invoice]
.cache()
val p1 = owners
.join(invoices, Seq("a", "c", "e", "f", "b"))
.selectExpr(
"a",
"d",
"b",
"e",
"f",
"c",
"IF(o_qtty-i_qtty < 0,o_qtty,o_qtty - i_qtty) AS qtty",
"IF(o_qtty-i_qtty < 0,0,i_qtty) AS to_distribute"
)
val p2 = owners
.join(invoices, Seq("a", "c", "e", "f", "b"), "left_outer")
.filter(row => row.anyNull)
.drop(col("i_qtty"))
.withColumnRenamed("o_qtty", "qtty")
val distribute = p1
.groupBy("a", "d", "b", "e", "f")
.agg(sum(col("to_distribute")).as("to_distribute"))
val proportion = p2
.groupBy("a", "d", "b", "e", "f")
.agg(sum(col("qtty")).as("proportion"))
val result = p2
.join(distribute, Seq("a", "d", "b", "e", "f"))
.join(proportion, Seq("a", "d", "b", "e", "f"))
.withColumn(
"qtty",
round(
((col("to_distribute") / col("proportion")) * col("qtty")) + col(
"qtty"
),
2
)
)
.drop("to_distribute", "proportion")
.union(p1.drop("to_distribute"))
result.show(false)
result.selectExpr("SUM(qtty)").show()
owners.selectExpr("SUM(o_qtty)").show()
/*
+---+----+---+---+---+---+-----+
|a |d |b |e |f |c |qtty |
+---+----+---+---+---+---+-----+
|11 |2018|A |x |y |222|27.71|
|33 |2018|C |x |y |555|126.0|
|33 |2018|C |x |y |666|84.0 |
|11 |2018|A |x |y |444|69.29|
|11 |2017|A |x |y |666|35.0 |
|33 |2018|C |x |y |444|10.0 |
|22 |2018|B |x |y |555|20.0 |
|11 |2018|A |x |y |888|88.0 |
|11 |2018|A |x |y |666|65.0 |
+---+----+---+---+---+---+-----+
+---------+
|sum(qtty)|
+---------+
| 525.0|
+---------+
+-----------+
|sum(o_qtty)|
+-----------+
| 640.0|
+-----------+
*/
}
}
另外,请注意聚合不能产生任何负数。
【问题讨论】:
-
嗨@Michael,我测试了你的代码,我认为它与你的数据样本一起工作正常,或者我遗漏了一些东西。试试
result.selectExpr("SUM(qtty)").show(),我得到570。 -
嗨@Chema 谢谢 - 我已经编辑了更多列的示例以反映问题。如您所见,分配后的总量不同。另请注意,“d”列在 Invoice 中不存在,但在 Owner 中 - 我想我需要引入一些窗口功能,但我不确定具体如何操作。再次感谢您的大力支持
-
聚合规则是否已更改或保持不变?如果发票中不存在
d,它是如何计算的?最后,在不同年份有两个A 666,它是如何计算的?因为Invoice没有year类别。我需要知道那些业务规则。 -
聚合规则保持不变,对于
Owners包含多个年份的情况(d列)-“待分配”的数量必须按照每年的数量进行分配。它是分布内部的一种分布。谢谢 -
嗨@Michael,好的,我想我明白了,但是对于发票表中不存在的
Owner(99, "D", 888, 2018, "x", "y", 100),它是如何计算的?
标签: scala apache-spark join aggregate