【Question Title】: Spark RDD or SQL operations to compute conditional counts
【Posted】: 2018-08-17 11:57:21
【Question】:

As background, I'm trying to implement Kaplan-Meier in Spark. In particular, assume I have a DataFrame/Dataset with a Double column named data and an Int column named censorFlag (0 if censored, 1 if not; I prefer this over a Boolean type).

例子:

val df = Seq((1.0, 1), (2.3, 0), (4.5, 1), (0.8, 1), (0.7, 0), (4.0, 1), (0.8, 1)).toDF("data", "censorFlag").as[(Double, Int)] 

Now I need to compute a column wins that counts the instances of each data value. I achieved this with the following code:

val distDF = df.withColumn("wins", sum(col("censorFlag")).over(Window.partitionBy("data").orderBy("data")))
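
For the example df above, distDF comes out roughly as follows (row order from show may differ):

+----+----------+----+
|data|censorFlag|wins|
+----+----------+----+
|0.7 |0         |0   |
|0.8 |1         |2   |
|0.8 |1         |2   |
|1.0 |1         |1   |
|2.3 |0         |0   |
|4.0 |1         |1   |
|4.5 |1         |1   |
+----+----------+----+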

The problem comes when I need to compute a quantity called atRisk that counts, for each value of data, the number of data points greater than or equal to it (a cumulative filtered count, if you will).

The following code works:

// We perform the counts per value of "bins". This is an array of doubles
val bins = df.select(col("data").as("dataBins")).distinct().sort("dataBins").as[Double].collect 
val atRiskCounts = bins.map(x => (x, df.filter(col("data").geq(x)).count)).toSeq.toDF("data", "atRisk")
// this works:
atRiskCounts.show
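
For the example df, that shows (roughly) the following, which illustrates the intended meaning of atRisk:

+----+------+
|data|atRisk|
+----+------+
|0.7 |7     |
|0.8 |6     |
|1.0 |4     |
|2.3 |3     |
|4.0 |2     |
|4.5 |1     |
+----+------+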

However, the use case calls for bins to be derived from the data column itself, and I would rather keep it as a single-column Dataset (or, at worst, an RDD), but certainly not a local array. Yet this does not work:

// Here, 'bins' rightfully come from the data itself.
val bins = df.select(col("data").as("dataBins")).distinct().as[Double]
val atRiskCounts = bins.map(x => (x, df.filter(col("data").geq(x)).count)).toSeq.toDF("data", "atRisk")
// This doesn't work -- NullPointerException
atRiskCounts.show

Nor does this:

// Manually creating the bins and then parallelizing them.
val bins = Seq(0.7, 0.8, 1.0, 3.0).toDS
val atRiskCounts = bins.map(x => (x, df.filter(col("data").geq(x)).count)).toDF("data", "atRisk")
// Also fails with a NullPointerException
atRiskCounts.show

Another approach that does work, but is also unsatisfying from a parallelization point of view, is to use a Window:

// Do the counts in one fell swoop using a giant window per value.
val atRiskCounts = df.withColumn("atRisk", count("censorFlag").over(Window.orderBy("data").rowsBetween(0, Window.unboundedFollowing))).groupBy("data").agg(first("atRisk").as("atRisk"))
// Works, BUT, we get a "WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation." 
atRiskCounts.show

The last solution is not useful because it ends up shuffling my data into a single partition (and in that case I might as well go with option 1, which works).

The approach that succeeds is fine, except that the bins are not parallelized, which is something I would really like to keep if possible. I have looked at groupBy aggregations and pivot-type aggregations, but nothing seems to fit.

My question is: is there a way to compute the atRisk column in a distributed fashion? Also, why do I get a NullPointerException in the failing solutions?

EDIT per the comments

I did not originally post the NullPointerException because it did not seem to contain anything helpful. I will note that this is Spark installed via Homebrew on my MacBook Pro (Spark version 2.2.1, standalone localhost mode).

                18/03/12 11:41:00 ERROR ExecutorClassLoader: Failed to check existence of class <root>.package on REPL class server at spark://10.37.109.111:53360/classes
            java.net.URISyntaxException: Illegal character in path at index 36: spark://10.37.109.111:53360/classes/<root>/package.class
                at java.net.URI$Parser.fail(URI.java:2848)
                at java.net.URI$Parser.checkChars(URI.java:3021)
                at java.net.URI$Parser.parseHierarchical(URI.java:3105)
                at java.net.URI$Parser.parse(URI.java:3053)
                at java.net.URI.<init>(URI.java:588)
                at org.apache.spark.rpc.netty.NettyRpcEnv.openChannel(NettyRpcEnv.scala:327)
                at org.apache.spark.repl.ExecutorClassLoader.org$apache$spark$repl$ExecutorClassLoader$$getClassFileInputStreamFromSparkRPC(ExecutorClassLoader.scala:90)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader.findClassLocally(ExecutorClassLoader.scala:162)
                at org.apache.spark.repl.ExecutorClassLoader.findClass(ExecutorClassLoader.scala:80)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
                . . . .
            18/03/12 11:41:00 ERROR ExecutorClassLoader: Failed to check existence of class <root>.scala on REPL class server at spark://10.37.109.111:53360/classes
            java.net.URISyntaxException: Illegal character in path at index 36: spark://10.37.109.111:53360/classes/<root>/scala.class
                at java.net.URI$Parser.fail(URI.java:2848)
                at java.net.URI$Parser.checkChars(URI.java:3021)
                at java.net.URI$Parser.parseHierarchical(URI.java:3105)
                at java.net.URI$Parser.parse(URI.java:3053)
                at java.net.URI.<init>(URI.java:588)
                at org.apache.spark.rpc.netty.NettyRpcEnv.openChannel(NettyRpcEnv.scala:327)
                at org.apache.spark.repl.ExecutorClassLoader.org$apache$spark$repl$ExecutorClassLoader$$getClassFileInputStreamFromSparkRPC(ExecutorClassLoader.scala:90)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader.findClassLocally(ExecutorClassLoader.scala:162)
                at org.apache.spark.repl.ExecutorClassLoader.findClass(ExecutorClassLoader.scala:80)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
                . . .
            18/03/12 11:41:00 ERROR ExecutorClassLoader: Failed to check existence of class <root>.org on REPL class server at spark://10.37.109.111:53360/classes
            java.net.URISyntaxException: Illegal character in path at index 36: spark://10.37.109.111:53360/classes/<root>/org.class
                at java.net.URI$Parser.fail(URI.java:2848)
                at java.net.URI$Parser.checkChars(URI.java:3021)
                at java.net.URI$Parser.parseHierarchical(URI.java:3105)
                at java.net.URI$Parser.parse(URI.java:3053)
                at java.net.URI.<init>(URI.java:588)
                at org.apache.spark.rpc.netty.NettyRpcEnv.openChannel(NettyRpcEnv.scala:327)
                at org.apache.spark.repl.ExecutorClassLoader.org$apache$spark$repl$ExecutorClassLoader$$getClassFileInputStreamFromSparkRPC(ExecutorClassLoader.scala:90)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader.findClassLocally(ExecutorClassLoader.scala:162)
                at org.apache.spark.repl.ExecutorClassLoader.findClass(ExecutorClassLoader.scala:80)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
                . . .
            18/03/12 11:41:00 ERROR ExecutorClassLoader: Failed to check existence of class <root>.java on REPL class server at spark://10.37.109.111:53360/classes
            java.net.URISyntaxException: Illegal character in path at index 36: spark://10.37.109.111:53360/classes/<root>/java.class
                at java.net.URI$Parser.fail(URI.java:2848)
                at java.net.URI$Parser.checkChars(URI.java:3021)
                at java.net.URI$Parser.parseHierarchical(URI.java:3105)
                at java.net.URI$Parser.parse(URI.java:3053)
                at java.net.URI.<init>(URI.java:588)
                at org.apache.spark.rpc.netty.NettyRpcEnv.openChannel(NettyRpcEnv.scala:327)
                at org.apache.spark.repl.ExecutorClassLoader.org$apache$spark$repl$ExecutorClassLoader$$getClassFileInputStreamFromSparkRPC(ExecutorClassLoader.scala:90)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader.findClassLocally(ExecutorClassLoader.scala:162)
                at org.apache.spark.repl.ExecutorClassLoader.findClass(ExecutorClassLoader.scala:80)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
                . . .
            18/03/12 11:41:00 ERROR Executor: Exception in task 0.0 in stage 55.0 (TID 432)
            java.lang.NullPointerException
                at org.apache.spark.sql.Dataset.<init>(Dataset.scala:171)
                at org.apache.spark.sql.Dataset$.apply(Dataset.scala:62)
                at org.apache.spark.sql.Dataset.withTypedPlan(Dataset.scala:2889)
                at org.apache.spark.sql.Dataset.filter(Dataset.scala:1301)
                at $line124.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:33)
                at $line124.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:33)
                at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
                at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
                at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:395)
                at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:234)
                at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:228)
                at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
                at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
                at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
                at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
                at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
                at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
                at org.apache.spark.scheduler.Task.run(Task.scala:108)
                at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:338)
                at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
                at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
                at java.lang.Thread.run(Thread.java:748)
            18/03/12 11:41:00 WARN TaskSetManager: Lost task 0.0 in stage 55.0 (TID 432, localhost, executor driver): java.lang.NullPointerException
                at org.apache.spark.sql.Dataset.<init>(Dataset.scala:171)
                at org.apache.spark.sql.Dataset$.apply(Dataset.scala:62)
                at org.apache.spark.sql.Dataset.withTypedPlan(Dataset.scala:2889)
                at org.apache.spark.sql.Dataset.filter(Dataset.scala:1301)
                at $line124.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:33)
                at $line124.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:33)
                at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
                at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
                at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:395)
                at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:234)
                at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:228)
                at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
                at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
                at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
                at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
                at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
                at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
                at org.apache.spark.scheduler.Task.run(Task.scala:108)
                at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:338)

            18/03/12 11:41:00 ERROR TaskSetManager: Task 0 in stage 55.0 failed 1 times; aborting job
            org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 55.0 failed 1 times, most recent failure: Lost task 0.0 in stage 55.0 (TID 432, localhost, executor driver): java.lang.NullPointerException
                at org.apache.spark.sql.Dataset.<init>(Dataset.scala:171)
                at org.apache.spark.sql.Dataset$.apply(Dataset.scala:62)
                at org.apache.spark.sql.Dataset.withTypedPlan(Dataset.scala:2889)
                at org.apache.spark.sql.Dataset.filter(Dataset.scala:1301)
                at $anonfun$1.apply(<console>:33)
                at $anonfun$1.apply(<console>:33)
                at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
                at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
                at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:395)
                at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:234)
                at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:228)
                at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
                at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
                at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
                at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
                at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
                at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
                at org.apache.spark.scheduler.Task.run(Task.scala:108)
                at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:338)

            Driver stacktrace:
              at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1517)
              at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1505)
              at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1504)
              at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
              ... 50 elided
            Caused by: java.lang.NullPointerException
              at org.apache.spark.sql.Dataset.<init>(Dataset.scala:171)
              at org.apache.spark.sql.Dataset$.apply(Dataset.scala:62)
              at org.apache.spark.sql.Dataset.withTypedPlan(Dataset.scala:2889)
              at org.apache.spark.sql.Dataset.filter(Dataset.scala:1301)
              at $anonfun$1.apply(<console>:33)
              at $anonfun$1.apply(<console>:33)
              at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
              at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
              at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:395)
              at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:234)
              at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:228)
              at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
              at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
              at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
              at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
              at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
              at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
              at org.apache.spark.scheduler.Task.run(Task.scala:108)
              at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:338)
              at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
              at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
              at java.lang.Thread.run(Thread.java:748)

My best guess is that the df("data").geq(x).count part is what barfs, since not every node necessarily has x, hence the null pointer?

【Question Comments】:

  • Can you post the stack trace of the NullPointerException? Nothing jumps out at me immediately (not that it wouldn't to someone else), but the stack trace might help narrow things down.

Tags: scala apache-spark spark-dataframe rdd


【Solution 1】:

I have not tested this, so the syntax may be off, but I would do a series of joins:

I believe your first statement is equivalent to this: for each data value, count how many wins there are:

val distDF = df.groupBy($"data").agg(sum($"censorFlag").as("wins"))

Then, as you note, we can build a dataframe of the bins:

val distinctData = df.select($"data".as("dataBins")).distinct()

And then join on the >= condition:

val atRiskCounts = distDF.join(distinctData, distDF("data") >= distinctData("dataBins"))
  .groupBy($"data", $"wins")
  .count()

【Discussion】:

  • This is almost what I need, but like @Ramesh's solution it is backwards: you need to flip the join so the count is over data, not over the bins. I.e. that line should be: val atRiskCounts = bins.join(df, bins("dataBins") <= df("data")).groupBy("dataBins").count().withColumnRenamed("count", "atRisk"). Otherwise this works like a charm! I have not yet tested it against the other solutions, but for now it looks correct. Will update with a larger test and accept shortly.
【Solution 2】:

Collecting is essential when you need to check a value in a column against all the other values in that column. When every value has to be compared, all of the data in that column necessarily has to be accumulated on a single executor or on the driver. Given your requirement, you cannot avoid this step.

Now the main part is how to define the remaining steps so that they benefit from Spark's parallelization. I would suggest you broadcast the collected set (since it is only the distinct data of one column, it cannot be very large) and use a udf function to check the gte condition, as below.

First, you can optimize your collect step:

import org.apache.spark.sql.functions._
val collectedData = df.select(sort_array(collect_set("data"))).collect()(0)(0).asInstanceOf[collection.mutable.WrappedArray[Double]]

Then you broadcast the collected set:

val broadcastedArray = sc.broadcast(collectedData)

The next step is to define a udf function that checks the gte condition and returns the counts:

def checkingUdf = udf((data: Double)=> broadcastedArray.value.count(x => x >= data))

and use it as:

distDF.withColumn("atRisk", checkingUdf(col("data"))).show(false)

So in the end you should have:

+----+----------+----+------+
|data|censorFlag|wins|atRisk|
+----+----------+----+------+
|4.5 |1         |1   |1     |
|0.7 |0         |0   |6     |
|2.3 |0         |0   |3     |
|1.0 |1         |1   |4     |
|0.8 |1         |2   |5     |
|0.8 |1         |2   |5     |
|4.0 |1         |1   |2     |
+----+----------+----+------+

I hope this is the dataframe you need.

【Discussion】:

  • Thanks for the suggestion! While this is also something I had explored (but forgot to mention), it is actually the opposite of what I want: I want the atRisk column to count how many data instances are geq each point in broadcastedArray, not the other way around (i.e. the value 0.7 would have 7 compared to your 6, since you are missing the extra 0.8). I am thinking about how to reverse it, but it looks like @hoyland's answer may be the simplest.
  • The 6 at risk for the 0.7 data value is because I collected as a set, and in a set duplicates count only once. If instead you use collect_array, you will get 7. Does that make sense? If you want to go with hoyland's simple solution, go for it. The choice is always yours. But be careful: having multiple groupBys and aggregations plus a join also requires a lot of shuffling, so always test them before you choose. You may well find a better solution yourself :)
  • The problem with that fix is that you would effectively be broadcasting the whole data column, which could be expensive. However, with the broadcast array of distinct values I can use the line from my 'failed solution': broadcastedArray.map(x => (x, df.filter(col("data").geq(x)).count)).toDF("data", "atRisk"), which works exactly as expected, so thanks! Now I just need to figure out which one is faster on a larger dataset. Will post a comparison later.
【Solution 3】:

I tried the examples above (though not in the most rigorous way!), and overall it looks like the join works best.

The data:

import org.apache.spark.mllib.random.RandomRDDs._
val df = logNormalRDD(sc, 1, 3.0, 10000, 100).zip(uniformRDD(sc, 10000, 100).map(x => if(x <= 0.4) 1 else 0)).toDF("data", "censorFlag").withColumn("data", round(col("data"), 2))

The join example:

import org.apache.spark.SparkContext
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._

def runJoin(sc: SparkContext, df: DataFrame): Unit = {
  val bins = df.select(col("data").as("dataBins")).distinct().sort("dataBins")
  val wins = df.groupBy(col("data")).agg(sum("censorFlag").as("wins"))
  val atRiskCounts = bins.join(df, bins("dataBins") <= df("data")).groupBy("dataBins").count().withColumnRenamed("count", "atRisk")
  val finalDF = wins.join(atRiskCounts, wins("data") === atRiskCounts("dataBins")).select("data", "wins", "atRisk").sort("data")
  finalDF.show
}

The broadcast example:

def runBroadcast(sc: SparkContext, df: DataFrame): Unit = {
  val bins = df.select(sort_array(collect_set("data"))).collect()(0)(0).asInstanceOf[collection.mutable.WrappedArray[Double]]
  val binsBroadcast = sc.broadcast(bins)
  val df2 = binsBroadcast.value.map(x => (x, df.filter(col("data").geq(x)).select(count(col("data"))).as[Long].first)).toDF("data", "atRisk")
  val finalDF = df.groupBy(col("data")).agg(sum("censorFlag").as("wins")).join(df2, "data")
  finalDF.show
  binsBroadcast.destroy
}

And the test code:

import java.util.concurrent.TimeUnit

var start = System.nanoTime()
runJoin(sc, sampleDF)
val joinTime = TimeUnit.SECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS)

start = System.nanoTime()
runBroadcast(sc, sampleDF)
val broadTime = TimeUnit.SECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS)

I ran this code for random data of varying sizes, also supplying a manual bins array (some very fine-grained, 50% of the original distinct data, some quite coarse, 10% of the original distinct data), and consistently the join approach looks the fastest (although both reach the same answer, which is a plus!).

On average, I found that the smaller the bins array, the better the broadcast method does, while the join does not seem to be affected much. If I had more time/resources to test this, I would run a large number of simulations to look at average runtimes, but for now I will accept @hoyland's solution.
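
For reference, a minimal sketch of what such a repeated-timing loop could look like (untested; the repetition count is arbitrary, and runJoin, runBroadcast and sampleDF are the definitions above):

import java.util.concurrent.TimeUnit

// Average wall-clock seconds over a few repetitions to smooth out warm-up and caching effects.
def timeAvg(label: String, reps: Int)(body: => Unit): Unit = {
  val secs = (1 to reps).map { _ =>
    val start = System.nanoTime()
    body
    TimeUnit.SECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS)
  }
  println(s"$label: avg ${secs.sum.toDouble / reps} s over $reps runs")
}

timeAvg("join", 5) { runJoin(sc, sampleDF) }
timeAvg("broadcast", 5) { runBroadcast(sc, sampleDF) }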

I am still not sure why the original approaches do not work, so I am open to comments.

Please let me know of any issues or improvements in my code! Thanks to you both :)

【Discussion】:
