下面是一个有效的 Spark 代码(至少在你的例子中它给出了正确的结果)。
由于 2 个笛卡尔积,代码效率不高。
区间比较的条件也可能需要注意。
请随时改进代码并在此处发布改进后的答案。
import org.apache.spark.{SparkConf, SparkContext}
object Main {
  // Spark wiring kept as object fields so the program is self-contained when run locally.
  val conf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("IntervalOverlapCount")
  val sc = new SparkContext(conf)

  /** A closed interval [start, end] on the real line. */
  case class Interval(start: Double, end: Double)

  def main(args: Array[String]): Unit = {
    sc.setLogLevel("ERROR")

    val input = List(Interval(1, 7), Interval(2, 3), Interval(6, 8))
    val inputRdd = sc.parallelize(input)

    // Sweep-line instead of the two cartesian products (O(n^2) -> O(n log n)):
    // every input interval contributes a +1 event at its start and a -1 event
    // at its end.  Right after processing boundary b, the running sum of the
    // sorted events equals the number of input intervals with
    //   start <= b  and  end > b,
    // and since every end IS a boundary, `end > b` is equivalent to
    // `end >= nextBoundary` — exactly the original overlap condition for the
    // elementary interval (b, nextBoundary).
    val events = inputRdd
      .flatMap(iv => Seq((iv.start, 1), (iv.end, -1)))
      // The infinities delimit the first and last elementary interval;
      // they carry no event (delta 0), as in the original boundary list.
      .union(sc.parallelize(List(Double.NegativeInfinity, Double.PositiveInfinity)).map(b => (b, 0)))
      .reduceByKey(_ + _) // merge events sharing a boundary (also deduplicates boundaries)
      .sortByKey()
      .cache()            // reused twice below: for partition offsets and for the running sum

    // Distributed prefix sum: collect one small Int per partition (the
    // partition's event total) and turn those into a starting offset for
    // each partition, so the running sum can be computed partition-locally.
    val partitionTotals = events.values
      .mapPartitions(it => Iterator(it.sum))
      .collect()
    val offsets = partitionTotals.scanLeft(0)(_ + _)

    // (boundary, number of input intervals covering the elementary interval starting there)
    val coverage = events.mapPartitionsWithIndex { (pid, it) =>
      var acc = offsets(pid)
      it.map { case (boundary, delta) => acc += delta; (boundary, acc) }
    }

    // Pair each boundary with its successor (keyed by position in sorted
    // order) to materialise the elementary intervals.
    val indexed = coverage.zipWithIndex()
    val starts = indexed.map { case ((b, c), i) => (i, (b, c)) }
    val ends = indexed.map { case ((b, _), i) => (i - 1, b) }

    val countsPerInterval = starts
      .join(ends) // the last boundary (+inf) has no successor and drops out here
      .map { case (_, ((s, c), e)) => ((s, e), c) }
      // The original cartesian solution only emitted intervals overlapped by
      // at least one input interval; match that by dropping zero counts.
      .filter { case (_, count) => count > 0 }

    // Same output shape as before: ((start, end), count), sorted by interval.
    countsPerInterval.sortBy(_._1).collect().foreach(println)

    sc.stop() // release the SparkContext (was leaked before)
  }
}