【问题标题】:Memoizing a function of type [Integer] -> a记忆 [Integer] -> a 类型的函数
【发布时间】:2015-03-05 14:22:04
【问题描述】:

我的问题是如何有效地记忆一个昂贵的函数f :: [Integer] -> a,它是为所有有限整数列表定义并具有属性f . sort = f

我的典型用例是给定一个整数列表as,我需要获取各种整数 a 的值f (a:as),所以我想同时构建一个有向标记图,其顶点是一对整数列表及其函数值。一条从 (as, f as) 到 (bs, f bs) 的边当且仅当 a:as = bs 时才存在。

brilliant answer by Edward Kmett 中窃取我只是复制

{-# LANGUAGE BangPatterns #-}
data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
  fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
  (q,0) -> index l q
  (q,1) -> index r q

nats :: Tree Integer
nats = go 0 1
  where go !n !s = Tree (go l s') n (go r s')
          where l = n + s
                r = l + s
                s' = s * 2

并将他的想法应用于我的问题

-- directed graph labelled by Integers
data Graph a = Graph a (Tree (Graph a))
instance Functor Graph where
  fmap f (Graph a t) = Graph (f a) (fmap (fmap f) t)

-- walk the graph following the given labels
walk :: Graph a -> [Integer] -> a
walk (Graph a _) [] = a
walk (Graph _ t) (x:xs) = walk (index t x) xs

-- graph of all finite integer sequences
intSeq :: Graph [Integer]
intSeq = Graph [] (fmap (\n -> fmap (n:) intSeq) nats)

-- could be replaced by Data.Strict.Pair
data StrictPair a b = StrictPair !a !b
  deriving Show

-- f = sum modified according to Edward's idea (the real function is more complicated)
g :: ([Integer] -> StrictPair Integer [Integer]) -> [Integer] -> StrictPair Integer [Integer]
g mf [] = StrictPair 0 []
g mf (a:as) = StrictPair (a+x) (a:as)
  where StrictPair x y = mf as

g_graph :: Graph (StrictPair Integer [Integer])
g_graph = fmap (g g_m) intSeq

g_m :: [Integer] -> StrictPair Integer [Integer]
g_m = walk g_graph

这工作正常,但由于函数f 独立于出现的整数的顺序(但不依赖于它们的计数),因此对于所有等于排序的整数列表,图中应该只有一个顶点。

我如何做到这一点?

【问题讨论】:

  • 记忆是一回事...en.wikipedia.org/wiki/Memoization请不要将其编辑为“记忆”
  • 列表中的所有元素都是非负数吗?
  • @Cirdec:是的! (抱歉忘记提了。)

标签: haskell memoization


【解决方案1】:

只定义g_m' = g_m . sort 怎么样,即您只需在调用记忆函数之前先对输入列表进行排序?

我觉得这是你能做的最好的事情,因为如果你希望你的记忆图只包含排序的路径,那么在构建路径之前,某人将不得不查看列表的所有元素。

根据您的输入列表的外观,以减少树分支的方式对其进行转换可能会有所帮助。例如,您可以尝试对差异进行排序和取舍:

original input list:   [8,3,14,8,5]
sorted:                [3,3,8,8,14]
diffed:                [3,0,5,0,6] -- use this as the key

变换是双射的,因为涉及的数字较少,所以树的分支较少。

【讨论】:

  • 我想避免从根部遍历许多整数 a 的固定整数列表 as 以获得 f (a:as) 因为列表可能会变得很长。
  • 所以您似乎需要两件事:1)在计算 f([2..100]) 之后,您将计算 f([1..100]) 但您不想再次遍历图表以检索 @987654325 的缓存值@。并且 2) 在计算 f([1,5]) 之后,您希望 f([5,1]) 找到先前计算的值。对吗?
  • 是的!也许我应该把我的问题的第二段放在最后??
【解决方案2】:

您可以使用一些不同的方法。 有一个技巧可以证明可数集的有限乘积是可数的:

我们可以通过product . zipWith (^) primes2 ^ a1 * 3 ^ a2 * 5 ^ a3 * ... * primen ^ an[a1, ..., an]序列映射到Nat

为了避免结尾为零的序列出现问题,我们可以增加最后一个索引。

由于序列是有序的,我们可以利用user5402 提到的属性。

使用树的好处是可以增加分支来加速遍历。 OTOH 的主要技巧可以使索引变得相当大,但希望某些树路径将未被探索(保留为 thunk)。

{-# LANGUAGE BangPatterns #-}

-- Modified from Kmett's answer:
data Tree a = Tree a (Tree a) (Tree a) (Tree a) (Tree a)
instance Functor Tree where
  fmap f (Tree x a b c d) = Tree (f x) (fmap f a) (fmap f b) (fmap f c) (fmap f d)

index :: Tree a -> Integer -> a
index (Tree x _ _ _ _) 0 = x
index (Tree _ a b c d) n = case (n - 1) `divMod` 4 of
  (q,0) -> index a q
  (q,1) -> index b q
  (q,2) -> index c q
  (q,3) -> index d q

nats :: Tree Integer
nats = go 0 1
    where
        go !n !s = Tree n (go a s') (go b s') (go c s') (go d s')
            where
                a = n + s
                b = a + s
                c = b + s
                d = c + s
                s' = s * 4

toList :: Tree a -> [a]
toList as = map (index as) [0..]

-- Primes -- https://www.haskell.org/haskellwiki/Prime_numbers
-- Generation and factorisation could be done much better
minus (x:xs) (y:ys) = case (compare x y) of
           LT -> x : minus  xs  (y:ys)
           EQ ->     minus  xs     ys
           GT ->     minus (x:xs)  ys
minus  xs     _     = xs

primes = 2 : sieve [3..] primes
  where
    sieve xs (p:ps) | q <- p*p , (h,t) <- span (< q) xs =
                   h ++ sieve (t `minus` [q, q+p..]) ps

addToLast :: [Integer] -> [Integer]
addToLast [] = []
addToLast [x] = [x + 1]
addToLast (x:xs) = x : addToLast xs

subFromLast :: [Integer] -> [Integer]
subFromLast [] = []
subFromLast [x] = [x - 1]
subFromLast (x:xs) = x : subFromLast xs

addSubProp :: [NonNegative Integer] -> Property
addSubProp xs = xs' === subFromLast (addToLast xs')
  where xs' = map getNonNegative xs

-- Trick from user5402 answer
toDiffList :: [Integer] -> [Integer]
toDiffList = toDiffList' 0
  where toDiffList' _ [] = []
        toDiffList' p (x:xs) = x - p : toDiffList' x xs

fromDiffList :: [Integer] -> [Integer]
fromDiffList = fromDiffList' 0
  where fromDiffList' _ [] = []
        fromDiffList' p (x:xs) = p + x : fromDiffList' (x + p) xs

diffProp :: [Integer] -> Property
diffProp xs = xs === fromDiffList (toDiffList xs)

listToInteger :: [Integer] -> Integer
listToInteger = product . zipWith (^) primes . addToLast

integerToList :: Integer -> [Integer]
integerToList = subFromLast . impl primes 0
  where impl _      _ 0 = []
        impl _      0 1 = []
        impl _      k 1 = [k]
        impl (p:ps) k n = case n `divMod` p of
                            (n', 0) -> impl (p:ps) (k + 1) n'
                            (_,  _) -> k : impl ps 0 n

listProp :: [NonNegative Integer] -> Property
listProp xs = xs' === integerToList (listToInteger xs')
  where xs' = map getNonNegative xs

toIndex :: [Integer] -> Integer
toIndex = listToInteger . toDiffList

fromIndex :: Integer -> [Integer]
fromIndex = fromDiffList . integerToList

-- [1,0] /= [0]
-- Decreasing sequence!
doesntHold :: [NonNegative Integer] -> Property
doesntHold xs = xs' === fromIndex (toIndex xs')
  where xs' = map getNonNegative xs

holds :: [NonNegative Integer] -> Property
holds xs = xs' === fromIndex (toIndex xs')
  where xs' = sort $ map getNonNegative xs

g :: ([Integer] -> Integer) -> [Integer] -> Integer
g mg = g' . sort
  where g' [] = 0
        g' (x:xs)  = x + sum (map mg $ tails xs)

g_tree :: Tree Integer
g_tree = fmap (g faster_g' . fromIndex) nats

faster_g' :: [Integer] -> Integer
faster_g' = index g_tree . toIndex

faster_g = faster_g' . sort

在我的机器上fix g [1..22] 感觉很慢,而faster_g [1..40] 仍然非常快。


加法如果我们有界集(索引0..n-1),我们可以将其编码为:@987654329 @。

我们可以将任何Integer 编码为二进制列表,例如11[1, 1, 0, 1](最小位在前)。 然后,如果我们用2 分隔列表中的整数,我们会得到有界值序列。

作为奖励,我们可以使用 0、1、2 位的序列并将其 压缩 为二进制,例如使用霍夫曼编码,因为 2 比 0 或 1 少得多。但这可能有点矫枉过正。

有了这个技巧,索引会保持更小,空间可能会被更好地压缩。

{-# LANGUAGE BangPatterns #-}

-- From Kment's answer:
import Data.Function (fix)
import Data.List (sort, tails)
import Data.List.Split (splitOn)
import Test.QuickCheck

{-- Tree definition as before --}

-- 0, 1, 2
newtype N3 = N3 { unN3 :: Integer }
  deriving (Eq, Show)

instance Arbitrary N3 where
  arbitrary = elements $ map N3 [ 0, 1, 2 ]

-- Integer <-> N3
coeffs3 :: [Integer]
coeffs3 = coeffs' 1
  where coeffs' n = n : coeffs' (n * 3)

listToInteger :: [N3] -> Integer
listToInteger = sum . zipWith f coeffs3
  where f n (N3 m) = n * m

listFromInteger :: Integer -> [N3]
listFromInteger 0 = []
listFromInteger n = case n `divMod` 3 of
  (q, m) -> N3 m : listFromInteger q

listProp :: [N3] -> Property
listProp xs = (null xs || last xs /= N3 0) ==> xs === listFromInteger (listToInteger xs)

-- Integer <-> N2

-- 0, 1
newtype N2 = N2 { unN2 :: Integer }
  deriving (Eq, Show)

coeffs2 :: [Integer]
coeffs2 = coeffs' 1
  where coeffs' n = n : coeffs' (n * 2)

integerToBin :: Integer -> [N2]
integerToBin 0 = []
integerToBin n = case n `divMod` 2 of
  (q, m) -> N2 m : integerToBin q

integerFromBin :: [N2] -> Integer
integerFromBin = sum . zipWith f coeffs2
  where f n (N2 m) = n * m

binProp :: NonNegative Integer -> Property
binProp (NonNegative n) = n === integerFromBin (integerToBin n)

-- unsafe!
n3ton2 :: N3 -> N2
n3ton2 = N2 . unN3

n2ton3 :: N2 -> N3
n2ton3 = N3 . unN2

-- [Integer] <-> [N3]
integerListToN3List :: [Integer] -> [N3]
integerListToN3List = concatMap (++ [N3 2]) . map (map n2ton3 . integerToBin)

integerListFromN3List :: [N3] -> [Integer]
integerListFromN3List = init . map (integerFromBin . map n3ton2) . splitOn [N3 2]

n3ListProp :: [NonNegative Integer] -> Property
n3ListProp xs = xs' === integerListFromN3List (integerListToN3List xs')
  where xs' = map getNonNegative xs

-- Trick from user5402 answer
-- Integer <-> Sorted Integer
toDiffList :: [Integer] -> [Integer]
toDiffList = toDiffList' 0
  where toDiffList' _ [] = []
        toDiffList' p (x:xs) = x - p : toDiffList' x xs

fromDiffList :: [Integer] -> [Integer]
fromDiffList = fromDiffList' 0
  where fromDiffList' _ [] = []
        fromDiffList' p (x:xs) = p + x : fromDiffList' (x + p) xs

diffProp :: [Integer] -> Property
diffProp xs = xs === fromDiffList (toDiffList xs)

---

toIndex :: [Integer] -> Integer
toIndex = listToInteger . integerListToN3List . toDiffList

fromIndex :: Integer -> [Integer]
fromIndex = fromDiffList . integerListFromN3List . listFromInteger

-- [1,0] /= [0]
-- Decreasing sequence! doesn't terminate in this case
doesntHold :: [NonNegative Integer] -> Property
doesntHold xs = xs' === fromIndex (toIndex xs')
  where xs' = map getNonNegative xs

holds :: [NonNegative Integer] -> Property
holds xs = xs' === fromIndex (toIndex xs')
  where xs' = sort $ map getNonNegative xs

g :: ([Integer] -> Integer) -> [Integer] -> Integer
g mg = g' . sort
  where g' [] = 0
        g' (x:xs)  = x + sum (map mg $ tails xs)

g_tree :: Tree Integer
g_tree = fmap (g faster_g' . fromIndex) nats

faster_g' :: [Integer] -> Integer
faster_g' = index g_tree . toIndex

faster_g = faster_g' . sort

第二次加法:

我快速为我的g 对图形和二进制序列方法进行了基准测试:

main :: IO ()
main = do
  n <- read . head <$> getArgs
  print $ faster_g [100, 110..n]

结果是:

% time ./IntegerMemo 1000
1225560638892526472150132981770
./IntegerMemo 1000  0.19s user 0.01s system 98% cpu 0.200 total
% time ./IntegerMemo 2000
3122858113354873680008305238045814042010921833620857170165770
./IntegerMemo 2000  1.83s user 0.05s system 99% cpu 1.888 total
% time ./IntegerMemo 2500
4399449191298176980662410776849867104410434903220291205722799441218623242250
./IntegerMemo 2500  3.74s user 0.09s system 99% cpu 3.852 total
% time ./IntegerMemo 3000    
5947985907461048240178371687835977247601455563536278700587949163642187584269899171375349770
./IntegerMemo 3000  6.66s user 0.13s system 99% cpu 6.830 total

% time ./IntegerMemoGrap 1000 
1225560638892526472150132981770
./IntegerMemoGrap 1000  0.10s user 0.01s system 97% cpu 0.113 total
% time ./IntegerMemoGrap 2000
3122858113354873680008305238045814042010921833620857170165770
./IntegerMemoGrap 2000  0.97s user 0.04s system 98% cpu 1.028 total
% time ./IntegerMemoGrap 2500
4399449191298176980662410776849867104410434903220291205722799441218623242250
./IntegerMemoGrap 2500  2.11s user 0.08s system 99% cpu 2.202 total
% time ./IntegerMemoGrap 3000 
5947985907461048240178371687835977247601455563536278700587949163642187584269899171375349770
./IntegerMemoGrap 3000  3.33s user 0.09s system 99% cpu 3.452 total

看起来该图形版本的速度要快 2 的常数因子。但是它们似乎具有相同的时间复杂度:)

【讨论】:

  • 您的意思是改写 2^a1*3^a2*5^a3... 吗?我的列表将有几千个条目,其值高达几千,所以我不想将整数列表编码为(相当长的)整数。 PS:感谢您的长回答,检查可能需要一天...
  • 是的,primen^an。如果条目之间的增量相对较小,则结果整数不会“变大”。 faster_g [100,120..2000] 的执行速度仍然相对较快。 如果你知道值也是从顶部开始的,你可以使用a1 + a2 * N + a3 * N * N + ...编码,它更简洁
  • @j.p.似乎您的Graph 版本是我的二进制技巧的两倍。不是说Graph一个更瘦!
【解决方案3】:

看来我的问题只需将g_graph 定义中的intSeq 替换为单调版本即可解决:

-- replace vertexes for non-monotone integer lists by the according monotone one
monoIntSeq :: Graph [Integer]
monoIntSeq = f intSeq
  where f (Graph as t) | as == sort as = Graph as $ fmap f t
                       | otherwise     = fetch monIntSeq $ sort as

-- extract the subgraph after following the given labels
fetch :: Graph a -> [Integer] -> Graph a
fetch g [] = g
fetch (Graph _ t) (x:xs) = fetch (index t x) xs

g_graph :: Graph (StrictPair Integer [Integer])
g_graph = fmap (g g_m) monoIntSeq

非常感谢大家(尤其是 user5402 和 Oleg)的帮助!


编辑:我的典型用例仍然存在内存消耗过高的问题,可以通过以下路径来描述:

p :: [Integer]
p = map f [1..]
  where f n | n `mod` 6 == 0 = n `div` 6
            | n `mod` 3 == 0 = n `div` 3
            | n `mod` 2 == 0 = n `div` 2
            | otherwise      = n

一个小小的改进是像这样直接定义单调整数序列:

-- extract the subgraph after following the given labels (right to left)
fetch :: Graph a -> [Integer] -> Graph a
fetch = foldl' step
  where step (Graph _ t) n = index t n

-- walk the graph following the given labels (right to left)
walk :: Graph a -> [Integer] -> a
walk g ns = a
  where Graph a _ = fetch g ns

-- all monotone falling integer sequences
monoIntSeqs :: Graph [Integer]
monoIntSeqs = Graph [] $ fmap (flip f monoIntSeqs) nats
  where f n (Graph ns t) | null ns      = Graph (n:ns) $ fmap (f n) t
                         | n >= head ns = Graph (n:ns) $ fmap (f n) t
                         | otherwise    = fetch monoIntSeqs (insert' n ns)
        insert' = insertBy (comparing Down)

但最后我可能只使用原始整数序列而不进行标识,不时显式标识节点,并避免保留对 g_graph 等的引用,以便在程序进行时让垃圾收集清理。

【讨论】:

    【解决方案4】:

    阅读 Richard Bird 和 Ralf Hinze 的功能性珍珠 Trouble Shared is Trouble Halved,我了解了如何实现,两年前我一直在寻找什么(再次基于 Edward Kmett 的技巧):

    {-# LANGUAGE BangPatterns #-}
    import Data.Function (fix)
    
    data Tree a = Tree (Tree a) a (Tree a)
      deriving Show
    
    instance Functor Tree where
      fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)
    
    index :: Tree a -> Integer -> a
    index (Tree _ m _) 0 = m
    index (Tree l _ r) n = case (n - 1) `divMod` 2 of
      (q,0) -> index l q
      (q,1) -> index r q
    
    nats :: Tree Integer
    nats = go 0 1
      where go !n !s = Tree (go l s') n (go r s')
              where l = n + s
                    r = l + s
                    s' = s * 2
    
    data IntSeqTree a = IntSeqTree a (Tree (IntSeqTree a))
    
    val :: IntSeqTree a -> a
    val (IntSeqTree a _) = a
    
    step :: Integer -> IntSeqTree t -> IntSeqTree t
    step n (IntSeqTree _ ts) = index ts n
    
    intSeqTree :: IntSeqTree [Integer]
    intSeqTree = fix $ create []
      where create p x = IntSeqTree p $ fmap (extend x) nats
            extend x n = case span (>n) (val x) of
                           ([], p) -> fix $ create (n:p)
                           (m, p)  -> foldr step intSeqTree (m ++ n:p)
    
    instance Functor IntSeqTree where
      fmap f (IntSeqTree a t) = IntSeqTree (f a) (fmap (fmap f) t)
    

    在我的用例中,我有成百上千个增量生成的类似整数序列(长度为几百个条目)。所以对我来说,这种方式比在查找函数值之前对序列进行排序要便宜(我将通过在 intSeqTree 上使用 fmap 来访问)。

    【讨论】:

      猜你喜欢
      • 2012-12-22
      • 2011-12-13
      • 2013-12-31
      • 1970-01-01
      • 1970-01-01
      • 2014-06-02
      • 2011-12-19
      • 2020-11-07
      • 2022-12-17
      相关资源
      最近更新 更多