【问题标题】:Memoization in Haskell?Haskell中的记忆?
【发布时间】:2011-03-13 14:28:44
【问题描述】:

关于如何在 Haskell 中有效解决以下函数的任何指针,对于大数 (n > 108)

f(n) = max(n, f(n/2) + f(n/3) + f(n/4))

我在 Haskell 中看到过用于解决斐波那契问题的记忆示例 数字,涉及(懒惰地)计算所有斐波那契数 直到所需的 n。但在这种情况下,对于给定的 n,我们只需要 计算很少的中间结果。

谢谢

【问题讨论】:

  • 仅在某种意义上说这是我在家做的一些工作:-)

标签: haskell memoization


【解决方案1】:

不是最有效的方法,但可以记忆:

f = 0 : [ g n | n <- [1..] ]
    where g n = max n $ f!!(n `div` 2) + f!!(n `div` 3) + f!!(n `div` 4)

请求f !! 144时,检查f !! 143是否存在,但不计算其确切值。它仍然被设置为一些未知的计算结果。计算出的唯一准确值是需要的值。

所以最初,对于计算了多少,程序一无所知。

f = .... 

当我们发出请求f !! 12,它开始做一些模式匹配:

f = 0 : g 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

现在开始计算

f !! 12 = g 12 = max 12 $ f!!6 + f!!4 + f!!3

这递归地对f提出另一个要求,所以我们计算

f !! 6 = g 6 = max 6 $ f !! 3 + f !! 2 + f !! 1
f !! 3 = g 3 = max 3 $ f !! 1 + f !! 1 + f !! 0
f !! 1 = g 1 = max 1 $ f !! 0 + f !! 0 + f !! 0
f !! 0 = 0

现在我们可以涓流备份一些

f !! 1 = g 1 = max 1 $ 0 + 0 + 0 = 1

这意味着程序现在知道了:

f = 0 : 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

继续涓涓细流:

f !! 3 = g 3 = max 3 $ 1 + 1 + 0 = 3

这意味着程序现在知道了:

f = 0 : 1 : g 2 : 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

现在我们继续计算f!!6

f !! 6 = g 6 = max 6 $ 3 + f !! 2 + 1
f !! 2 = g 2 = max 2 $ f !! 1 + f !! 0 + f !! 0 = max 2 $ 1 + 0 + 0 = 2
f !! 6 = g 6 = max 6 $ 3 + 2 + 1 = 6

这意味着程序现在知道了:

f = 0 : 1 : 2 : 3 : g 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

现在我们继续计算f!!12

f !! 12 = g 12 = max 12 $ 6 + f!!4 + 3
f !! 4 = g 4 = max 4 $ f !! 2 + f !! 1 + f !! 1 = max 4 $ 2 + 1 + 1 = 4
f !! 12 = g 12 = max 12 $ 6 + 4 + 3 = 13

这意味着程序现在知道了:

f = 0 : 1 : 2 : 3 : 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : 13 : ...

所以计算是相当懒惰的。程序知道f !! 8 存在某个值,它等于g 8,但它不知道g 8 是什么。

【讨论】:

  • 谢谢你的这个。您将如何创建和使用二维解决方案空间?那会是一个列表列表吗?和g n m = (something with) f!!a!!b
  • 当然可以。不过,对于一个真正的解决方案,我可能会使用一个记忆库,比如memocombinators
  • 不幸的是 O(n^2)。
【解决方案2】:

我们可以通过创建一个可以在亚线性时间内索引的结构来非常有效地做到这一点。

但首先,

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

让我们定义f,但让它使用“开放递归”而不是直接调用自身。

f :: (Int -> Int) -> Int -> Int
f mf 0 = 0
f mf n = max n $ mf (n `div` 2) +
                 mf (n `div` 3) +
                 mf (n `div` 4)

您可以使用fix f 获得一个未记忆的f

这将让您测试 f 通过调用 f 的小值是否符合您的意思,例如:fix f 123 = 144

我们可以通过定义来记住这一点:

f_list :: [Int]
f_list = map (f faster_f) [0..]

faster_f :: Int -> Int
faster_f n = f_list !! n

它的表现还算不错,并且用可以记住中间结果的东西代替了将要花费 O(n^3) 时间的东西。

但是仅仅索引找到mf的记忆答案仍然需要线性时间。这意味着结果如下:

*Main Data.List> faster_f 123801
248604

是可以容忍的,但结果并没有比这更好。我们可以做得更好!

首先,让我们定义一棵无限树:

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

然后我们将定义一种索引方式,这样我们就可以在 O(log n) 时间内找到索引为 n 的节点:

index :: Tree a -> Int -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

...我们可能会发现一棵充满自然数的树很方便,因此我们不必摆弄这些索引:

nats :: Tree Int
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

既然我们可以索引,你可以把树转换成列表:

toList :: Tree a -> [a]
toList as = map (index as) [0..]

您可以通过验证toList nats 为您提供[0..] 来检查到目前为止的工作

现在,

f_tree :: Tree Int
f_tree = fmap (f fastest_f) nats

fastest_f :: Int -> Int
fastest_f = index f_tree

与上面的列表一样工作,但不是花费线性时间来查找每个节点,而是可以在对数时间内追踪它。

结果要快得多:

*Main> fastest_f 12380192300
67652175206

*Main> fastest_f 12793129379123
120695231674999

事实上,它的速度要快得多,您可以通过上面的Int 替换Integer 并几乎立即获得大得离谱的答案

*Main> fastest_f' 1230891823091823018203123
93721573993600178112200489

*Main> fastest_f' 12308918230918230182031231231293810923
11097012733777002208302545289166620866358

对于实现基于树的记忆的开箱即用库,请使用MemoTrie

$ stack repl --package MemoTrie
Prelude> import Data.MemoTrie
Prelude Data.MemoTrie> :set -XLambdaCase
Prelude Data.MemoTrie> :{
Prelude Data.MemoTrie| fastest_f' :: Integer -> Integer
Prelude Data.MemoTrie| fastest_f' = memo $ \case
Prelude Data.MemoTrie|   0 -> 0
Prelude Data.MemoTrie|   n -> max n (fastest_f'(n `div` 2) + fastest_f'(n `div` 3) + fastest_f'(n `div` 4))
Prelude Data.MemoTrie| :}
Prelude Data.MemoTrie> fastest_f' 12308918230918230182031231231293810923
11097012733777002208302545289166620866358

【讨论】:

  • 我试过这段代码,有趣的是,f_faster 似乎比 f 慢。我猜这些列表引用确实减慢了速度。 nats 和 index 的定义对我来说似乎很神秘,所以我添加了自己的答案,这可能会让事情更清楚。
  • 无限列表的情况要处理一个长111111111项的链表。树的情况是处理 log n * 到达的节点数。
  • 即列表版本必须为列表中的所有节点创建 thunk,而树版本避免创建大量节点。
  • 我知道这是一篇相当老的帖子,但不应该在 where 子句中定义 f_tree 以避免在调用树中保存不需要的路径吗?
  • 把它塞进 CAF 的原因是你可以在调用中获得记忆。如果我要记住一个昂贵的电话,那么我可能会把它留在 CAF 中,因此这里显示的技术。在实际应用中,当然需要在永久记忆的收益和成本之间进行权衡。虽然,考虑到问题是关于如何实现记忆化,我认为用一种故意避免跨调用记忆化的技术来回答会产生误导,如果没有别的,那么这里的评论将向人们指出存在微妙之处的事实。 ;)
【解决方案3】:

这是对 Edward Kmett 出色答案的补充。

当我尝试他的代码时,natsindex 的定义似乎很神秘,所以我写了一个我觉得更容易理解的替代版本。

我根据index'nats' 定义indexnats

index' t n[1..] 范围内定义。 (回想一下index t 是在[0..] 范围内定义的。)它通过将n 视为一串位并反向读取位来搜索树。如果该位是1,则它采用右手分支。如果该位是0,则它采用左侧分支。它在到达最后一位时停止(必须是1)。

index' (Tree l m r) 1 = m
index' (Tree l m r) n = case n `divMod` 2 of
                          (n', 0) -> index' l n'
                          (n', 1) -> index' r n'

正如nats 是为index 定义的,所以index nats n == n 始终为真,nats' 是为index' 定义的。

nats' = Tree l 1 r
  where
    l = fmap (\n -> n*2)     nats'
    r = fmap (\n -> n*2 + 1) nats'
    nats' = Tree l 1 r

现在,natsindex 只是 nats'index',但值移动了 1:

index t n = index' t (n+1)
nats = fmap (\n -> n-1) nats'

【讨论】:

  • 谢谢。我正在记忆一个多元函数,这确实帮助我弄清楚了 index 和 nats 到底在做什么。
【解决方案4】:

Edward's answer 是一个非常棒的宝石,我复制了它并提供了 memoListmemoTree 组合子的实现,它们以开放递归的形式存储函数。

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

f :: (Integer -> Integer) -> Integer -> Integer
f mf 0 = 0
f mf n = max n $ mf (div n 2) +
                 mf (div n 3) +
                 mf (div n 4)


-- Memoizing using a list

-- The memoizing functionality depends on this being in eta reduced form!
memoList :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoList f = memoList_f
  where memoList_f = (memo !!) . fromInteger
        memo = map (f memoList_f) [0..]

faster_f :: Integer -> Integer
faster_f = memoList f


-- Memoizing using a tree

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

nats :: Tree Integer
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

toList :: Tree a -> [a]
toList as = map (index as) [0..]

-- The memoizing functionality depends on this being in eta reduced form!
memoTree :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoTree f = memoTree_f
  where memoTree_f = index memo
        memo = fmap (f memoTree_f) nats

fastest_f :: Integer -> Integer
fastest_f = memoTree f

【讨论】:

    【解决方案5】:

    Edward Kmett 回答的另一个附录:一个独立的例子:

    data NatTrie v = NatTrie (NatTrie v) v (NatTrie v)
    
    memo1 arg_to_index index_to_arg f = (\n -> index nats (arg_to_index n))
      where nats = go 0 1
            go i s = NatTrie (go (i+s) s') (f (index_to_arg i)) (go (i+s') s')
              where s' = 2*s
            index (NatTrie l v r) i
              | i <  0    = f (index_to_arg i)
              | i == 0    = v
              | otherwise = case (i-1) `divMod` 2 of
                 (i',0) -> index l i'
                 (i',1) -> index r i'
    
    memoNat = memo1 id id 
    

    如下使用它来记忆具有单个整数 arg 的函数(例如斐波那契):

    fib = memoNat f
      where f 0 = 0
            f 1 = 1
            f n = fib (n-1) + fib (n-2)
    

    只有非负参数的值会被缓存。

    要同时缓存负参数的值,请使用memoInt,定义如下:

    memoInt = memo1 arg_to_index index_to_arg
      where arg_to_index n
             | n < 0     = -2*n
             | otherwise =  2*n + 1
            index_to_arg i = case i `divMod` 2 of
               (n,0) -> -n
               (n,1) ->  n
    

    要缓存具有两个整数参数的函数的值,请使用memoIntInt,定义如下:

    memoIntInt f = memoInt (\n -> memoInt (f n))
    

    【讨论】:

      【解决方案6】:

      正如 Edward Kmett 的回答中所述,为了加快速度,您需要缓存昂贵的计算并能够快速访问它们。

      为了保持函数非单子,构建无限惰性树的解决方案,用适当的方法来索引它(如以前的帖子所示)实现了这个目标。如果您放弃函数的非单子性质,您可以将 Haskell 中可用的标准关联容器与“类状态”单子(如 State 或 ST)结合使用。

      虽然主要缺点是您获得了一个非单子函数,但您不必再自己索引结构,并且可以使用关联容器的标准实现。

      为此,您首先需要重写您的函数以接受任何类型的 monad:

      fm :: (Integral a, Monad m) => (a -> m a) -> a -> m a
      fm _    0 = return 0
      fm recf n = do
         recs <- mapM recf $ div n <$> [2, 3, 4]
         return $ max n (sum recs)
      

      对于您的测试,您仍然可以使用 Data.Function.fix 定义一个不做记忆的函数,尽管它有点冗长:

      noMemoF :: (Integral n) => n -> n
      noMemoF = runIdentity . fix fm
      

      然后您可以结合使用 State monad 和 Data.Map 来加快速度:

      import qualified Data.Map.Strict as MS
      
      withMemoStMap :: (Integral n) => n -> n
      withMemoStMap n = evalState (fm recF n) MS.empty
         where
            recF i = do
               v <- MS.lookup i <$> get
               case v of
                  Just v' -> return v' 
                  Nothing -> do
                     v' <- fm recF i
                     modify $ MS.insert i v'
                     return v'
      

      只需稍作改动,您就可以调整代码以与 Data.HashMap 一起使用:

      import qualified Data.HashMap.Strict as HMS
      
      withMemoStHMap :: (Integral n, Hashable n) => n -> n
      withMemoStHMap n = evalState (fm recF n) HMS.empty
         where
            recF i = do
               v <- HMS.lookup i <$> get
               case v of
                  Just v' -> return v' 
                  Nothing -> do
                     v' <- fm recF i
                     modify $ HMS.insert i v'
                     return v'
      

      除了持久性数据结构,您还可以尝试将可变数据结构(如 Data.HashTable)与 ST monad 结合使用:

      import qualified Data.HashTable.ST.Linear as MHM
      
      withMemoMutMap :: (Integral n, Hashable n) => n -> n
      withMemoMutMap n = runST $
         do ht <- MHM.new
            recF ht n
         where
            recF ht i = do
               k <- MHM.lookup ht i
               case k of
                  Just k' -> return k'
                  Nothing -> do 
                     k' <- fm (recF ht) i
                     MHM.insert ht i k'
                     return k'
      

      与没有任何记忆的实现相比,这些实现中的任何一个都允许您在大量输入的情况下在几微秒内获得结果,而不必等待几秒钟。

      使用 Criterion 作为基准,我可以观察到,使用 Data.HashMap 的实现实际上比使用时间非常相似的 Data.Map 和 Data.HashTable 的性能略好(大约 20%)。

      我发现基准测试的结果有点令人惊讶。我最初的感觉是 HashTable 会优于 HashMap 实现,因为它是可变的。最后一个实现中可能隐藏了一些性能缺陷。

      【讨论】:

      • GHC 在围绕不可变结构进行优化方面做得非常好。来自 C 的直觉并不总是成功。
      【解决方案7】:

      几年后,我看到这个并意识到有一种简单的方法可以使用 zipWith 和一个辅助函数在线性时间内记住这个:

      dilate :: Int -> [x] -> [x]
      dilate n xs = replicate n =<< xs
      

      dilate 具有 dilate n xs !! i == xs !! div i n 的便利属性。

      所以,假设给定 f(0),这将计算简化为

      fs = f0 : zipWith max [1..] (tail $ fs#/2 .+. fs#/3 .+. fs#/4)
        where (.+.) = zipWith (+)
              infixl 6 .+.
              (#/) = flip dilate
              infixl 7 #/
      

      看起来很像我们最初的问题描述,并给出一个线性解决方案(sum $ take n fs 需要 O(n))。

      【讨论】:

      • 所以它是一种生成(核心递归?)或动态编程解决方案。每个生成的值花费 O(1) 时间,就像通常的斐波那契一样。伟大的! EKMETT 的解决方案就像对数大斐波那契,更快地得到大数字,跳过很多中间值。这是对的吗?
      • 或者它可能更接近于汉明数,三个反向指针指向正在生成的序列,并且每个反向指针沿着它前进的速度不同。真的很漂亮。
      【解决方案8】:

      一种没有索引且不基于 Edward KMETT 的解决方案。

      我将公共子树分解为公共父级(f(n/4)f(n/2)f(n/4) 之间共享,f(n/6)f(2)f(3) 之间共享)。通过将它们保存为父变量中的单个变量,子树的计算就完成了一次。

      data Tree a =
        Node {datum :: a, child2 :: Tree a, child3 :: Tree a}
      
      f :: Int -> Int
      f n = datum root
        where root = f' n Nothing Nothing
      
      
      -- Pass in the arg
        -- and this node's lifted children (if any).
      f' :: Integral a => a -> Maybe (Tree a) -> Maybe (Tree a)-> a
      f' 0 _ _ = leaf
          where leaf = Node 0 leaf leaf
      f' n m2 m3 = Node d c2 c3
        where
          d = if n < 12 then n
                  else max n (d2 + d3 + d4)
          [n2,n3,n4,n6] = map (n `div`) [2,3,4,6]
          [d2,d3,d4,d6] = map datum [c2,c3,c4,c6]
          c2 = case m2 of    -- Check for a passed-in subtree before recursing.
            Just c2' -> c2'
            Nothing -> f' n2 Nothing (Just c6)
          c3 = case m3 of
            Just c3' -> c3'
            Nothing -> f' n3 (Just c6) Nothing
          c4 = child2 c2
          c6 = f' n6 Nothing Nothing
      
          main =
            print (f 123801)
            -- Should print 248604.
      

      代码不容易扩展到一般的记忆功能(至少,我不知道该怎么做),你真的必须考虑子问题如何重叠,但是策略 应该适用于一般的多个非整数参数。 (我想到了两个字符串参数。)

      备忘录在每次计算后被丢弃。 (再次,我在考虑两个字符串参数。)

      我不知道这是否比其他答案更有效。从技术上讲,每次查找只需一两个步骤(“看看您的孩子或您孩子的孩子”),但可能会占用大量额外的内存。

      编辑:这个解决方案还不正确。分享不完整。

      编辑:它现在应该正确共享子子项,但我意识到这个问题有很多不平凡的共享:n/2/2/2n/3/3 可能是相同的。这个问题不适合我的策略。

      【讨论】:

        猜你喜欢
        • 2020-10-25
        • 2016-04-11
        • 2011-11-16
        • 2018-09-08
        • 1970-01-01
        • 2011-10-20
        • 1970-01-01
        • 2013-05-11
        • 2011-07-29
        相关资源
        最近更新 更多