【问题标题】:Scala, tail recursion vs. non tail recursion, why is tail recursion slower?Scala,尾递归与非尾递归,为什么尾递归更慢?
【发布时间】:2023-09-10 04:22:01
【问题描述】:

我在向一位朋友解释说,我希望 Scala 中的非尾递归函数比尾递归函数慢,所以我决定验证一下。 我以两种方式编写了一个很好的旧阶乘函数并尝试比较结果。代码如下:

def main(args: Array[String]): Unit = {
  val N = 2000 // not too much or else *s
  var spent1: Long = 0
  var spent2: Long = 0
  for ( i <- 1 to 100 ) { // repeat to average the results
    val t0 = System.nanoTime
    factorial(N)
    val t1 = System.nanoTime
    tailRecFact(N)
    val t2 = System.nanoTime
    spent1 += t1 - t0
    spent2 += t2 - t1
  }
  println(spent1/1000000f) // get milliseconds
  println(spent2/1000000f)
}

@tailrec
def tailRecFact(n: BigInt, s: BigInt = 1): BigInt = if (n == 1) s else tailRecFact(n - 1, s * n)

def factorial(n: BigInt): BigInt = if (n == 1) 1 else n * factorial(n - 1)

结果让我很困惑,我得到这样的输出:

578.2985

870.22125

意思是非尾递归函数比尾递归函数快30%,而且运算次数是一样的!

如何解释这些结果?

【问题讨论】:

    标签: scala tail-recursion


    【解决方案1】:

    这实际上不是你首先看的地方。原因在于你的尾递归方法,你正在用它的乘法做更多的工作。尝试在递归调用中交换参数 n 和 s 的顺序,它会均匀。

    def tailRecFact(n: BigInt, s: BigInt): BigInt = if (n == 1) s else tailRecFact(n - 1, n * s)
    

    此外,此示例中的大部分时间都用于 BigInt 操作,这使递归调用的时间相形见绌。如果我们将这些转换为 Ints(编译为 Java 原语),那么您可以看到尾递归(goto)与方法调用的比较。

    object Test extends App {
    
      val N = 2000
    
      val t0 = System.nanoTime()
      for ( i <- 1 to 1000 ) {
        factorial(N)
      }
      val t1 = System.nanoTime
      for ( i <- 1 to 1000 ) {
        tailRecFact(N, 1)
      }
      val t2 = System.nanoTime
    
      println((t1 - t0) / 1000000f) // get milliseconds
      println((t2 - t1) / 1000000f)
    
      def factorial(n: Int): Int = if (n == 1) 1 else n * factorial(n - 1)
    
      @tailrec
      final def tailRecFact(n: Int, s: Int): Int = if (n == 1) s else tailRecFact(n - 1, s * n)
    }
    
    95.16733
    3.987605
    

    感兴趣的,反编译的输出

      public final scala.math.BigInt tailRecFact(scala.math.BigInt, scala.math.BigInt);
        Code:
           0: aload_1       
           1: iconst_1      
           2: invokestatic  #16                 // Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
           5: invokestatic  #20                 // Method scala/runtime/BoxesRunTime.equalsNumObject:(Ljava/lang/Number;Ljava/lang/Object;)Z
           8: ifeq          13
          11: aload_2       
          12: areturn       
          13: aload_1       
          14: getstatic     #26                 // Field scala/math/BigInt$.MODULE$:Lscala/math/BigInt$;
          17: iconst_1      
          18: invokevirtual #30                 // Method scala/math/BigInt$.int2bigInt:(I)Lscala/math/BigInt;
          21: invokevirtual #36                 // Method scala/math/BigInt.$minus:(Lscala/math/BigInt;)Lscala/math/BigInt;
          24: aload_1       
          25: aload_2       
          26: invokevirtual #39                 // Method scala/math/BigInt.$times:(Lscala/math/BigInt;)Lscala/math/BigInt;
          29: astore_2      
          30: astore_1      
          31: goto          0
    
      public scala.math.BigInt factorial(scala.math.BigInt);
        Code:
           0: aload_1       
           1: iconst_1      
           2: invokestatic  #16                 // Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
           5: invokestatic  #20                 // Method scala/runtime/BoxesRunTime.equalsNumObject:(Ljava/lang/Number;Ljava/lang/Object;)Z
           8: ifeq          21
          11: getstatic     #26                 // Field scala/math/BigInt$.MODULE$:Lscala/math/BigInt$;
          14: iconst_1      
          15: invokevirtual #30                 // Method scala/math/BigInt$.int2bigInt:(I)Lscala/math/BigInt;
          18: goto          40
          21: aload_1       
          22: aload_0       
          23: aload_1       
          24: getstatic     #26                 // Field scala/math/BigInt$.MODULE$:Lscala/math/BigInt$;
          27: iconst_1      
          28: invokevirtual #30                 // Method scala/math/BigInt$.int2bigInt:(I)Lscala/math/BigInt;
          31: invokevirtual #36                 // Method scala/math/BigInt.$minus:(Lscala/math/BigInt;)Lscala/math/BigInt;
          34: invokevirtual #47                 // Method factorial:(Lscala/math/BigInt;)Lscala/math/BigInt;
          37: invokevirtual #39                 // Method scala/math/BigInt.$times:(Lscala/math/BigInt;)Lscala/math/BigInt;
          40: areturn   
    

    【讨论】:

    • 请详细说明。我没有看到“两倍”的调用?
    • 为什么需要 2 次隐式转换? 's' 和 'n' 已经是 BigInt。在“阶乘”中,您有 nfactorial(n-1),因此您有 1 的乘法和减法。即每个递归步骤有 2 个 BigInt 操作。在 tailRecFact 中,您有 n-1 和 ns,这也是每个递归步骤的 2 个 BigInt 操作。
    • 仅供参考,我删除了我的评论,因为我没有正确解释我的意思。我会更新我的答案。
    • 更新了,其实完全出乎我的意料。
    • @monkjack,你知道为什么交换 n 和 s 会如此影响时间吗?在字节码中没有太大区别,只是交换了两条aload指令。
    【解决方案2】:

    除了@monkjack 显示的问题(即乘小 * 大比大 * 小快,这确实占了更大的差异),您的算法在每种情况下都是不同的,所以它们并不是真的可比。

    在尾递归版本中,您将大到小相乘:

    n * n-1 * n-2 * ... * 2 * 1
    

    在非尾递归版本中,您将小到大相乘:

    n * (n-1 * (n-2 * (... * (2 * 1))))
    

    如果你改变尾递归版本,让它从小到大:

    def tailRecFact2(n: BigInt) = {
      def loop(x: BigInt, out: BigInt): BigInt =
        if (x > n) out else loop(x + 1, x * out)
      loop(1, 1)
    }
    

    那么尾递归比普通递归快大约 20%,而不是像你只进行 Monkjack 的修正那样慢 10%。这是因为将小的 BigInts 相乘比将大的 BigInts 相乘要快。

    【讨论】:

    • 这对我来说有点奇怪。是因为您可以将较小的 BigInts 保留在缓存中吗?
    • 这个答案有助于解释 BigInteger 的奇/差性能:*.com/a/17590529/770361