我建议跳过重新分区和排序步骤,直接使用"分布式压缩归并排序"(这个名称和算法本身一样,都是我刚刚发明的)。
这是算法中应该用作reduce操作的部分:
// A gap (interval) is an inclusive (start, end) pair of Ints.
type Gap = (Int, Int)
/** Merge-step of a merge-sort variant that works directly on compressed
  * integer sequences: each sequence is a sorted list of non-overlapping
  * inclusive ranges. Any two ranges that overlap or touch
  * (next start <= current end) are coalesced into one.
  *
  * @param as sorted, non-overlapping gaps; must be non-empty
  * @param bs sorted, non-overlapping gaps; must be non-empty
  * @return sorted, non-overlapping gaps covering the union of `as` and `bs`
  */
def mergeIntervals(as: List[Gap], bs: List[Gap]): List[Gap] = {
require(!as.isEmpty, "as must be non-empty")
require(!bs.isEmpty, "bs must be non-empty")
// Tail-recursive worker. `gaps` accumulates finished gaps in REVERSE order;
// (gapStart, gapEndAccum) is the gap currently being grown; `as`/`bs` are
// the still-unconsumed tails of the two inputs.
@annotation.tailrec
def mergeRec(
gaps: List[Gap],
gapStart: Int,
gapEndAccum: Int,
as: List[Gap],
bs: List[Gap]
): List[Gap] = {
as match {
case Nil => {
bs match {
// Both inputs exhausted: close the current gap and finish.
case Nil => (gapStart, gapEndAccum) :: gaps
// Only `bs` remains: continue with it in the `as` position.
case notEmpty => mergeRec(gaps, gapStart, gapEndAccum, bs, Nil)
}
}
case (a0, a1) :: at => {
if (a0 <= gapEndAccum) {
// Head of `as` overlaps/touches the current gap: absorb it.
mergeRec(gaps, gapStart, gapEndAccum max a1, at, bs)
} else {
bs match {
// `bs` exhausted and `a` is disjoint: emit current gap, start a new one at `a`.
case Nil => mergeRec((gapStart, gapEndAccum) :: gaps, a0, gapEndAccum max a1, at, bs)
case (b0, b1) :: bt => if (b0 <= gapEndAccum) {
// Head of `bs` overlaps/touches the current gap: absorb it.
mergeRec(gaps, gapStart, gapEndAccum max b1, as, bt)
} else {
// Both heads are disjoint from the current gap: emit it and
// start a new gap at whichever head begins first.
if (a0 < b0) {
mergeRec((gapStart, gapEndAccum) :: gaps, a0, a1, at, bs)
} else {
mergeRec((gapStart, gapEndAccum) :: gaps, b0, b1, as, bt)
}
}
}
}
}
}
}
// Seed the recursion with whichever input starts first.
val (a0, a1) :: at = as
val (b0, b1) :: bt = bs
val reverseRes =
if (a0 < b0)
mergeRec(Nil, a0, a1, at, bs)
else
mergeRec(Nil, b0, b1, as, bt)
// Gaps were accumulated in reverse; restore ascending order.
reverseRes.reverse
}
它的工作原理如下:
// Demo: merging two sorted gap lists coalesces all overlapping ranges.
val left  = List((0, 3), (4, 7), (9, 11), (15, 16), (18, 22))
val right = List((1, 2), (4, 5), (6, 10), (12, 13), (15, 17))
println(mergeIntervals(left, right))
// Outputs:
// List((0,3), (4,11), (12,13), (15,17), (18,22))
现在,如果将它与 Spark 的并行 reduce 结合使用,
// Groups rows by `typ` and merges each group's intervals with a parallel
// reduce. NOTE(review): reduceByKey requires the merge function to be
// associative and commutative — the discussion below states mergeIntervals
// satisfies this; confirm before relying on it.
val mergedIntervals = df.
as[(String, Int, Int)].
rdd.
map{case (t, s, e) => (t, List((s, e)))}. // Convert start end to list with one interval
reduceByKey(mergeIntervals). // perform parallel compressed merge-sort
flatMap{ case (k, vs) => vs.map(v => (k, v._1, v._2))}.// explode resulting lists of merged intervals
toDF("typ", "start", "end") // convert back to DF
mergedIntervals.show()
你会得到类似并行合并排序的东西,它直接作用于整数序列的压缩表示(因此得名)。
结果:
+-----+-----+---+
| typ|start|end|
+-----+-----+---+
| Two| 10| 25|
| Two| 40| 45|
| One| 0| 8|
| One| 10| 15|
|Three| 30| 35|
+-----+-----+---+
讨论
mergeIntervals 方法实现了一个可交换且满足结合律的操作,用于合并已按升序排序的非重叠区间列表:它会合并所有重叠的区间,并再次按升序返回结果。该操作可以在 reduce 步骤中反复应用,直到所有区间序列都被合并。
该算法的有趣特性在于,它会最大程度地压缩每一步归约的中间结果。因此,如果区间之间有大量重叠,这个算法实际上可能比其他先对输入区间排序的算法更快。
但是,如果有大量区间且它们之间几乎不重叠,这种方法可能会耗尽内存而完全无法工作;此时必须改用其他算法:先对区间排序,再做某种扫描,将相邻的区间在本地合并。因此,这种方法是否可行取决于具体用例。
完整代码
// Sample input: (typ, start, end) rows; intervals within the same `typ`
// may overlap (e.g. "One" has (0, 5) and (5, 8), which touch at 5).
val df = Seq(
("One", 0, 5),
("One", 10, 15),
("One", 5, 8),
("Two", 10, 25),
("Two", 40, 45),
("Three", 30, 35)
).toDF("typ", "start", "end")
// A gap (interval) is an inclusive (start, end) pair of Ints.
type Gap = (Int, Int)

/** The `merge`-step of a variant of merge-sort
  * that works directly on compressed sequences of integers,
  * where instead of individual integers, the sequence is
  * represented by sorted, non-overlapping ranges of integers.
  *
  * The two inputs are first interleaved into a single list sorted by range
  * start, then a single pass coalesces every pair of ranges that overlap or
  * touch (next start <= current end).
  *
  * @param as sorted, non-overlapping gaps; must be non-empty
  * @param bs sorted, non-overlapping gaps; must be non-empty
  * @return sorted, non-overlapping gaps covering the union of both inputs
  */
def mergeIntervals(as: List[Gap], bs: List[Gap]): List[Gap] = {
  require(!as.isEmpty, "as must be non-empty")
  require(!bs.isEmpty, "bs must be non-empty")

  // Standard tail-recursive merge of two ascending lists, keyed on start.
  @annotation.tailrec
  def interleave(xs: List[Gap], ys: List[Gap], acc: List[Gap]): List[Gap] =
    (xs, ys) match {
      case (Nil, Nil)     => acc.reverse
      case (x :: xt, Nil) => interleave(xt, Nil, x :: acc)
      case (Nil, y :: yt) => interleave(Nil, yt, y :: acc)
      case (x :: xt, y :: yt) =>
        if (x._1 < y._1) interleave(xt, ys, x :: acc)
        else interleave(xs, yt, y :: acc)
    }

  // Single pass over the ascending list: grow the current gap while the
  // next range overlaps or touches it, otherwise emit it and start anew.
  @annotation.tailrec
  def coalesce(rest: List[Gap], curStart: Int, curEnd: Int, acc: List[Gap]): List[Gap] =
    rest match {
      case Nil => ((curStart, curEnd) :: acc).reverse
      case (s, e) :: tail =>
        if (s <= curEnd) coalesce(tail, curStart, curEnd max e, acc)
        else coalesce(tail, s, e, (curStart, curEnd) :: acc)
    }

  // Safe: both inputs are non-empty, so the interleaved list is too.
  val (s0, e0) :: rest = interleave(as, bs, Nil)
  coalesce(rest, s0, e0, Nil)
}
// Groups rows by `typ` and merges each group's intervals with a parallel
// reduce. NOTE(review): reduceByKey requires the merge function to be
// associative and commutative — the discussion above states mergeIntervals
// satisfies this; confirm before relying on it.
val mergedIntervals = df.
as[(String, Int, Int)].
rdd.
map{case (t, s, e) => (t, List((s, e)))}. // Convert start end to list with one interval
reduceByKey(mergeIntervals). // perform parallel compressed merge-sort
flatMap{ case (k, vs) => vs.map(v => (k, v._1, v._2))}.// explode resulting lists of merged intervals
toDF("typ", "start", "end") // convert back to DF
mergedIntervals.show()
测试
mergeIntervals 的实现只经过了少量测试。如果你想把它真正合并进代码库,下面至少给出了一个可重复随机测试的草图:
/** Generates a random ascending list of disjoint inclusive ranges by
  * scanning indices upward from 0: at each index the scan may stop
  * (probability ~0.1%), toggle a range boundary (~10%), or simply advance.
  * Every emitted range satisfies start < end, and successive ranges are
  * strictly separated (previous end < next start). A range left open when
  * the scan stops is discarded.
  */
def randomIntervalSequence(): List[(Int, Int)] = {
  @annotation.tailrec
  def scan(done: List[(Int, Int)], pendingStart: Option[Int], idx: Int): List[(Int, Int)] =
    if (math.random > 0.999) done.reverse
    else if (math.random > 0.90) {
      pendingStart match {
        case None        => scan(done, Some(idx), idx + 1)               // open a range here
        case Some(start) => scan((start, idx) :: done, None, idx + 1)    // close the open range
      }
    } else scan(done, pendingStart, idx + 1)
  scan(Nil, None, 0)
}
/** Expands a list of inclusive ranges into every integer they cover, in order. */
def intervalsToInts(is: List[(Int, Int)]): List[Int] =
  is.flatMap { case (lo, hi) => (lo to hi).toList }
// Randomized property check: the merged list must cover exactly the union
// of the integers covered by the two inputs. Empty inputs are skipped
// (mergeIntervals requires non-empty lists), so only real runs are counted.
var numNonTrivialTests = 0
while (numNonTrivialTests < 1000) {
  val xs = randomIntervalSequence()
  val ys = randomIntervalSequence()
  if (xs.nonEmpty && ys.nonEmpty) {
    numNonTrivialTests += 1
    val expected = intervalsToInts(xs).toSet ++ intervalsToInts(ys)
    val actual = intervalsToInts(mergeIntervals(xs, ys)).toSet
    assert(expected == actual)
  }
}
显然,你需要根据所使用的测试框架,把这里原始的 assert 替换成更合适的断言方式。