【问题标题】:Free ~> Trampoline : recursive program crashes with OutOfMemoryError免费 ~> Trampoline : 递归程序因 OutOfMemoryError 而崩溃
【发布时间】:2016-12-13 14:01:09
【问题描述】:

假设我试图用一个操作实现一种非常简单的领域特定语言:

printLine(line)

然后我想编写一个程序,它以整数 n 作为输入,如果 n 可被 10k 整除,则打印一些内容,然后使用 n + 1 调用自身,直到 n 达到某个最大值 @987654327 @。

省略所有由 for-comprehensions 引起的句法噪音,我想要的是:

@annotation.tailrec def p(n: Int): Unit = {
  if (n % 10000 == 0) printLine("line")
  if (n > N) () else p(n + 1)
}

本质上,这将是一种“嘶嘶声”。

这里有一些尝试使用 Scalaz 7.3.0-M7 中的 Free monad 来实现这一点:

import scalaz._

object Demo1 {

  // define operations of a little domain specific language
  sealed trait Lang[X]
  case class PrintLine(line: String) extends Lang[Unit]

  // define the domain specific language as the free monad of operations
  type Prog[X] = Free[Lang, X]

  import Free.{liftF, pure}

  // lift operations into the free monad
  def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
  def ret: Prog[Unit] = Free.pure(())

  // write a program that is just a loop that prints current index 
  // after every few iteration steps
  val mod =  100000
  val N =   1000000

  // straightforward syntax: deadly slow, exits with OutOfMemoryError
  def p0(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    _ <- (if (i > N) ret else p0(i + 1))
  } yield ()

  // Same as above, but written out without `for`
  def p1(i: Int): Prog[Unit] = 
    (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
      ignore1 =>
      (if (i > N) ret else p1(i + 1)).map{ ignore2 => () }
    }

  // Same as above, with `map` attached to recursive call
  def p2(i: Int): Prog[Unit] = 
    (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
      ignore1 =>
      (if (i > N) ret else p2(i + 1).map{ ignore2 => () })
    }

  // Same as above, but without the `map`; performs ok.
  def p3(i: Int): Prog[Unit] = {
    (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{ 
      ignore1 =>
      if (i > N) ret else p3(i + 1)
    }
  }

  // Variation of the above; Ok.
  def p4(i: Int): Prog[Unit] = (for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
  } yield ()).flatMap{ ignored2 => 
    if (i > N) ret else p4(i + 1) 
  }

  // try to use the variable returned by the last generator after yield,
  // hope that the final `map` is optimized away (it's not optimized away...)
  def p5(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    stopHere <- (if (i > N) ret else p5(i + 1))
  } yield stopHere

  // define an interpreter that translates the programs into Trampoline
  import scalaz.Trampoline
  type Exec[X] = Free.Trampoline[X]  
  val interpreter = new (Lang ~> Exec) {
    def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
      case PrintLine(l) => Trampoline.delay(println(l))
    }
  }

  // try it out
  def main(args: Array[String]): Unit = {
    println("\n p0")
    p0(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    println("\n p1")
    p1(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    println("\n p2")
    p2(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    println("\n p3")
    p3(0).foldMap(interpreter).run // ok 
    println("\n p4")
    p4(0).foldMap(interpreter).run // ok
    println("\n p5")
    p5(0).foldMap(interpreter).run // OutOfMemory
  }
}

不幸的是,直接翻译 (p0) 似乎以某种 O(N^2) 开销运行,并因 OutOfMemoryError 而崩溃。问题似乎是for-comprehension 在递归调用p0 之后附加了一个map{x =&gt; ()},这迫使Free monad 用提醒填充整个内存“完成'p0'然后什么也不做”。 如果我手动“展开”for 理解,并明确写出最后一个flatMap(如p3p4),那么问题就消失了,一切顺利进行。然而,这是一个非常脆弱的解决方法:如果我们简单地将map(id) 附加到它,程序的行为会发生巨大变化,而这个map(id) 在代码中甚至不可见,因为它是由@ 自动生成的987654341@-理解。

在此较早的帖子中:https://apocalisp.wordpress.com/2011/10/26/tail-call-elimination-in-scala-monads/ 反复建议将递归调用包装到suspend 中。这是Applicative 实例和suspend 的尝试:

import scalaz._

// Essentially same as in `Demo1`, but this time with 
// an `Applicative` and an explicit `Suspend` in the 
// `for`-comprehension
object Demo2 {

  sealed trait Lang[H]

  case class Const[H](h: H) extends Lang[H]
  case class PrintLine[H](line: String) extends Lang[H]

  implicit object Lang extends Applicative[Lang] {
    def point[A](a: => A): Lang[A] = Const(a)
    def ap[A, B](a: => Lang[A])(f: => Lang[A => B]): Lang[B] = a match {
      case Const(x) => {
        f match {
          case Const(ab) => Const(ab(x))
          case _ => throw new Error
        }
      }
      case PrintLine(l) => PrintLine(l)
    }
  }

  type Prog[X] = Free[Lang, X]

  import Free.{liftF, pure}
  def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
  def ret: Prog[Unit] = Free.pure(())

  val mod = 100000
  val N = 2000000

  // try to suspend the entire second generator
  def p7(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    _ <- Free.suspend(if (i > N) ret else p7(i + 1))
  } yield ()

  // try to suspend the recursive call
  def p8(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    _ <- if (i > N) ret else Free.suspend(p8(i + 1))
  } yield ()

  import scalaz.Trampoline
  type Exec[X] = Free.Trampoline[X]

  val interpreter = new (Lang ~> Exec) {
    def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
      case Const(x) => Trampoline.done(x)
      case PrintLine(l) => 
        (Trampoline.delay(println(l))).asInstanceOf[Exec[A]]
    }
  }

  def main(args: Array[String]): Unit = {
    p7(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    p8(0).foldMap(interpreter).run // same...
  }
}

插入suspend 并没有真正帮助:它仍然很慢,并且与OutOfMemoryErrors 一起崩溃。

我应该以不同的方式使用suspend吗?

也许有一些纯粹的句法补救措施可以使用 for-comprehensions 而不最终生成 map

如果有人能指出我在这里做错了什么以及如何修复它,我将不胜感激。

【问题讨论】:

  • 嗨,我复制并运行了你的代码,它既不慢,也没有得到OutOfMemory。当我将 N 增加十倍时,与天真的 tailrec 解决方案(你有 O(N) 的地方)相比,它变得更慢(这是可以预料的,因为你应该得到 O(N*N)),但仍然没有 OOM 错误。
  • 这可能取决于 JVM 设置和硬件。如果您没有立即看到效果,请尝试使用类似 n = 1000000, N = 10.000.000 的值。在我的笔记本电脑上,一些程序的运行速度明显变慢,并且在 N = 5000000 时出现 OutOfMemory 失败。但是您应该会看到 N 值较小时速度变慢。

标签: scala out-of-memory scalaz tail-recursion free-monad


【解决方案1】:

Scala 编译器添加的多余的map 将递归从尾部位置移动到非尾部 位置。 Free monad 仍然使这个堆栈安全,但空间复杂度变为 O(N) 而不是 O(1)。 (具体来说还是不是O(N2)。)

是否有可能使scalac 优化map 离开会产生一个单独的问题(我不知道答案)。

我将尝试说明在解释 p1p3 时发生了什么。 (我将忽略对Trampoline 的翻译,这是多余的(见下文)。)

p3(即没有额外的map

让我使用以下简写:

def cont(i: Int): Unit => Prg[Unit] =
  ignore1 => if (i > N) ret else p3(i + 1)

现在p3(0)解释如下

p3(0)
printLine("i = " + 0) flatMap cont(0)
// side-effect: println("i = 0")
cont(0)
p3(1)
ret flatMap cont(1)
cont(1)
p3(2)
ret flatMap cont(2)
cont(2)

等等...您会看到在任何时候所需的内存量都不会超过某个恒定的上限。

p1(即额外的map

我将使用以下速记:

def cont(i: Int): Unit => Prg[Unit] =
  ignore1 => (if (i > N) ret else p1(i + 1)).map{ ignore2 => () }

def cpu: Unit => Prg[Unit] = // constant pure unit
  ignore => Free.pure(())

现在p1(0) 解释如下:

p1(0)
printLine("i = " + 0) flatMap cont(0)
// side-effect: println("i = 0")
cont(0)
p1(1) map { ignore2 => () }
// Free.map is implemented via flatMap
p1(1) flatMap cpu
(ret flatMap cont(1)) flatMap cpu
cont(1) flatMap cpu
(p1(2) map { ignore2 => () }) flatMap cpu
(p1(2) flatMap cpu) flatMap cpu
((ret flatMap cont(2)) flatMap cpu) flatMap cpu
(cont(2) flatMap cpu) flatMap cpu
((p1(3) map { ignore2 => () }) flatMap cpu) flatMap cpu
((p1(3) flatMap cpu) flatMap cpu) flatMap cpu
(((ret flatMap cont(3)) flatMap cpu) flatMap cpu) flatMap cpu

等等...您会看到内存消耗线性依赖于N。我们只是将评估从堆栈移动到堆。

带走:为了保持Free 对内存友好,将递归保持在“尾部位置”,即flatMap(或map)的右侧。

旁白: 不需要翻译成Trampoline,因为Free 已经被蹦床了。您可以直接解释为Id 并使用foldMapRec 进行堆栈安全解释:

val idInterpreter = new (Lang ~> Id) {
  def apply[A](cmd: Lang[A]): Id[A] = cmd match {
    case PrintLine(l) => println(l)
  }
}

p0(0).foldMapRec(idInterpreter)

这将为您恢复一些记忆(但不会让问题消失)。

【讨论】:

  • 非常感谢您的详细说明,它证实了我的直觉,即 p0 在运行时会在内存中留下 O(N)-trail。我不确定时间开销:对于一些较旧的 Free 实现,使用 F:Functor 将下一个操作附加到类似链表的结构的末尾,我可以想象它实际上可能是 O(N^ 2),但我得看看当前的实现,再考虑一下。
  • 在“旁白”:我使用Trampoline 只是为了说明,我可能会使用其他东西作为“解释目标”。推理“不需要翻译成 Trampoline,因为 Free 已经被 trampolined”似乎与这个问题的答案不同:stackoverflow.com/questions/29660067/…
  • 关于“优化map away”:可能类似于def noMap[X](x: X) = new { def map(f: Unit =&gt; Unit): X = x },然后将最后一个生成器包装在noMap 中?它可以工作,并消除了由for 生成的最后一个map,但如果它已经存在于某个地方(在Scalaz 或其他地方),那么使用更传统的东西会更好。
  • @AndreyTyukin 是的,有一个 Free 的实现依赖于 Functor,并且受到二次时间复杂度的影响。从那以后,情况有所改善。您链接到的答案是 pre-foldMapRec。我没有强调foldMapRec 在堆栈安全解释中对Id 的重要性。这是noMap 的一个巧妙技巧!我不知道“标准”解决方案,但如果您找到一个,我很想听听。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2012-03-29
  • 1970-01-01
  • 1970-01-01
  • 2023-03-14
相关资源
最近更新 更多