【问题标题】:Why is this simple haskell algorithm so slow?为什么这个简单的haskell算法这么慢?
【发布时间】:2011-12-28 17:45:53
【问题描述】:

剧透警告:这与来自 Project Euler 的 Problem 14 有关。

以下代码运行大约需要 15 秒。我有一个在 1 秒内运行的非递归 Java 解决方案。我想我应该能够让这段代码更接近那个。

import Data.List

collatz a 1  = a
collatz a x
  | even x    = collatz (a + 1) (x `div` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)

main = do
  print ((foldl1' max) . map (collatz 1) $ [1..1000000])

我使用+RHS -p 进行了分析,并注意到分配的内存很大,并且随着输入的增长而增长。对于n = 100,000,分配了1gb(!),对于n = 1,000,000,分配了13gb(!!)。

再一次,-sstderr 表明,虽然分配了很多字节,但总内存使用量为 1mb,效率为 95% 以上,所以可能 13gb 是红鲱鱼。

我能想到几种可能性:

  1. 有些事情并不像它需要的那样严格。我已经发现了 foldl1',但也许我需要做更多?是否可以标记collatz 一样严格(这有意义吗?

  2. collatz 不是尾调用优化。我认为应该是但不要 知道如何确认。

  3. 编译器没有做一些我认为应该做的优化——例如 任何时候只有collatz 的两个结果需要在内存中(最大值和当前)

有什么建议吗?

这几乎是Why is this Haskell expression so slow? 的复制品,不过我会注意到快速Java 解决方案不需要执行任何记忆。有什么方法可以加快速度而不必求助于它?

作为参考,这是我的分析输出:

  Wed Dec 28 09:33 2011 Time and Allocation Profiling Report  (Final)

     scratch +RTS -p -hc -RTS

  total time  =        5.12 secs   (256 ticks @ 20 ms)
  total alloc = 13,229,705,716 bytes  (excludes profiling overheads)

COST CENTRE                    MODULE               %time %alloc

collatz                        Main                  99.6   99.4


                                                                                               individual    inherited
COST CENTRE              MODULE                                               no.    entries  %time %alloc   %time %alloc

MAIN                     MAIN                                                   1           0   0.0    0.0   100.0  100.0
 CAF                     Main                                                 208          10   0.0    0.0   100.0  100.0
  collatz                Main                                                 215           1   0.0    0.0     0.0    0.0
  main                   Main                                                 214           1   0.4    0.6   100.0  100.0
   collatz               Main                                                 216           0  99.6   99.4    99.6   99.4
 CAF                     GHC.IO.Handle.FD                                     145           2   0.0    0.0     0.0    0.0
 CAF                     System.Posix.Internals                               144           1   0.0    0.0     0.0    0.0
 CAF                     GHC.Conc                                             128           1   0.0    0.0     0.0    0.0
 CAF                     GHC.IO.Handle.Internals                              119           1   0.0    0.0     0.0    0.0
 CAF                     GHC.IO.Encoding.Iconv                                113           5   0.0    0.0     0.0    0.0

和-sstderr:

./scratch +RTS -sstderr 
525
  21,085,474,908 bytes allocated in the heap
      87,799,504 bytes copied during GC
           9,420 bytes maximum residency (1 sample(s))          
          12,824 bytes maximum slop               
               1 MB total memory in use (0 MB lost due to fragmentation)  

  Generation 0: 40219 collections,     0 parallel,  0.40s,  0.51s elapsed
  Generation 1:     1 collections,     0 parallel,  0.00s,  0.00s elapsed

  INIT  time    0.00s  (  0.00s elapsed)
  MUT   time   35.38s  ( 36.37s elapsed)
  GC    time    0.40s  (  0.51s elapsed)
  RP    time    0.00s  (  0.00s elapsed)  PROF  time    0.00s  (  0.00s elapsed)
  EXIT  time    0.00s  (  0.00s elapsed)
  Total time   35.79s  ( 36.88s elapsed)  %GC time       1.1%  (1.4% elapsed)  Alloc rate    595,897,095 bytes per MUT second

  Productivity  98.9% of total user, 95.9% of total elapsed

还有 Java 解决方案(不是我的,取自 Project Euler 论坛,删除了记忆):

public class Collatz {
  public int getChainLength( int n )
  {
    long num = n;
    int count = 1;
    while( num > 1 )
    {
      num = ( num%2 == 0 ) ? num >> 1 : 3*num+1;
      count++;
    }
    return count;
  }

  public static void main(String[] args) {
    Collatz obj = new Collatz();
    long tic = System.currentTimeMillis();
    int max = 0, len = 0, index = 0;
    for( int i = 3; i < 1000000; i++ )
    {
      len = obj.getChainLength(i);
      if( len > max )
      {
        max = len;
        index = i;
      }
    }
    long toc = System.currentTimeMillis();
    System.out.println(toc-tic);
    System.out.println( "Index: " + index + ", length = " + max );
  }
}

【问题讨论】:

  • 令人惊讶的是,GHC 没有像任何自尊的 C 编译器所期望的那样将 (quot n 2) 优化为 (rshift n 1)。有什么原因吗?
  • @solrize:我也很惊讶。

标签: haskell collatz


【解决方案1】:

起初,我认为您应该尝试在collatz 中的a 之前加一个感叹号:

collatz !a 1  = a
collatz !a x
  | even x    = collatz (a + 1) (x `div` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)

(您需要将{-# LANGUAGE BangPatterns #-} 放在源文件的顶部才能使其正常工作。)

我的推理如下:问题是您在 collat​​z 的第一个参数中构建了一个巨大的 thunk:它以 1 开始,然后变为 1 + 1,然后变成(1 + 1) + 1, ... 所有这些都没有被强迫。这个bang pattern 强制collatz 的第一个参数在每次调用时都被强制执行,所以它从 1 开始,然后变成 2,依此类推,没有建立一个大的未评估的 thunk:它只是保持为整数。

请注意,爆炸模式只是使用seq 的简写;在这种情况下,我们可以重写 collatz 如下:

collatz a _ | seq a False = undefined
collatz a 1  = a
collatz a x
  | even x    = collatz (a + 1) (x `div` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)

这里的技巧是强制 a 在守卫中,然后它总是评估为 False (因此身体是无关紧要的)。然后评估继续下一个案例,a 已经被评估。但是,刘海模式更清晰。

不幸的是,当使用-O2 编译时,它的运行速度并不比原来的快!我们还能尝试什么?好吧,我们可以做的一件事是假设这两个数字永远不会溢出机器大小的整数,并给collatz这个类型注释:

collatz :: Int -> Int -> Int

我们将把 bang 模式留在那里,因为我们仍然应该避免建立 thunk,即使它们不是性能问题的根源。这将我的(慢速)计算机上的时间缩短到 8.5 秒。

下一步是尝试使其更接近 Java 解决方案。首先要意识到的是,在 Haskell 中,div 在数学上以更正确的方式处理负整数,但比“正常”C 除法慢,在 Haskell 中称为 quot。将 div 替换为 quot 可将运行时间缩短至 5.2 秒,将 x `quot` 2 替换为 x `shiftR` 1(导入 Data.Bits)以匹配 Java 解决方案可将运行时间缩短至 4.9 秒。

这是我目前所能得到的最低值,但我认为这是一个相当不错的结果;由于您的计算机比我的要快,它应该更接近 Java 解决方案。

这是最终代码(我在途中做了一些清理工作):

{-# LANGUAGE BangPatterns #-}

import Data.Bits
import Data.List

collatz :: Int -> Int
collatz = collatz' 1
  where collatz' :: Int -> Int -> Int
        collatz' !a 1 = a
        collatz' !a x
          | even x    = collatz' (a + 1) (x `shiftR` 1)
          | otherwise = collatz' (a + 1) (3 * x + 1)

main :: IO ()
main = print . foldl1' max . map collatz $ [1..1000000]

查看该程序的 GHC 核心(使用ghc-core),我认为这可能已经差不多了; collatz 循环使用未装箱的整数,程序的其余部分看起来没问题。我能想到的唯一改进是从 map collatz [1..1000000] 迭代中消除拳击。

顺便说一句,不要担心“total alloc”这个数字;它是在程序的整个生命周期内分配的总内存,即使 GC 回收该内存,它也不会减少。数 TB 的数字很常见。

【讨论】:

  • 谢谢,这真的很有帮助。我不知道-O2,这有很大的不同(将运行时间降低到 5 秒)。为问题添加了 Java 解决方案。
  • 哦,我假设你已经在使用-O2,因为修改后的带有爆炸模式的程序在我的机器上运行了 16 秒 :) 我会看看你的 Java 解决方案。跨度>
  • 对了,Java程序其实以x = 3开头,但是对性能的影响可以忽略不计,感觉像作弊,所以我没有做Haskell程序也这样做:)
  • shiftR 将运行时间缩短至 1.5 秒。我对此很满意!再次感谢。
  • It turns out 你只能在 64 位机器上使用Int(或者实际上是Word)。该序列将在 32 位上溢出。只是说这个,为了这里未来可能的读者的利益。 :)
【解决方案2】:

您可能会丢失列表和 bang 模式,但改用堆栈仍然可以获得相同的性能。

import Data.List
import Data.Bits

coll :: Int -> Int
coll 0 = 0
coll 1 = 1
coll 2 = 2
coll n =
  let a = coll (n - 1)
      collatz a 1 = a
      collatz a x
        | even x    = collatz (a + 1) (x `shiftR` 1)
        | otherwise = collatz (a + 1) (3 * x + 1)
  in max a (collatz 1 n)


main = do
  print $ coll 100000

这样做的一个问题是,对于大输入,您将不得不增加堆栈的大小,例如 1_000_000。

更新:

这是一个没有堆栈溢出问题的尾递归版本。

import Data.Word
collatz :: Word -> Word -> (Word, Word)
collatz a x
  | x == 1    = (a,x)
  | even x    = collatz (a + 1) (x `quot` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)

coll :: Word -> Word
coll n = collTail 0 n
  where
    collTail m 1 = m
    collTail m n = collTail (max (fst $ collatz 1 n) m) (n-1)

注意使用Word 而不是Int。它在性能上有所不同。如果需要,您仍然可以使用 bang 模式,这将使性能几乎翻倍。

【讨论】:

    【解决方案3】:

    我发现有一件事对这个问题产生了惊人的影响。我坚持直接的递归关系而不是折叠,你应该原谅这个表达,用它来计数。重写

    collatz n = if even n then n `div` 2 else 3 * n + 1
    

    作为

    collatz n = case n `divMod` 2 of
                (n', 0) -> n'
                _       -> 3 * n + 1
    

    在具有 2.8 GHz Athlon II X4 430 CPU 的系统上,我的程序的运行时间缩短了 1.2 秒。我最初的更快版本(使用 divMod 后 2.3 秒):

    {-# LANGUAGE BangPatterns #-}
    
    import Data.List
    import Data.Ord
    
    collatzChainLen :: Int -> Int
    collatzChainLen n = collatzChainLen' n 1
        where collatzChainLen' n !l
                | n == 1    = l
                | otherwise = collatzChainLen' (collatz n) (l + 1)
    
    collatz:: Int -> Int
    collatz n = case n `divMod` 2 of
                     (n', 0) -> n'
                     _       -> 3 * n + 1
    
    pairMap :: (a -> b) -> [a] -> [(a, b)]
    pairMap f xs = [(x, f x) | x <- xs]
    
    main :: IO ()
    main = print $ fst (maximumBy (comparing snd) (pairMap collatzChainLen [1..999999]))
    

    一个可能更惯用的 Haskell 版本在大约 9.7 秒内运行(8.5 与 divMod);这是相同的保存

    collatzChainLen :: Int -> Int
    collatzChainLen n = 1 + (length . takeWhile (/= 1) . (iterate collatz)) n
    

    使用 Data.List.Stream 应该允许流融合,这将使这个版本运行起来更像是显式积累,但我找不到具有 Data.List.Stream 的 Ubuntu libghc* 包,所以我还无法验证。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2017-01-10
      • 2012-06-13
      • 1970-01-01
      • 1970-01-01
      • 2015-02-26
      • 2020-09-06
      • 2019-06-10
      • 1970-01-01
      相关资源
      最近更新 更多