【问题标题】:Leetcode 1155. dictionary based memoization vs LRU CacheLeetcode 1155. 基于字典的记忆与 LRU 缓存
【发布时间】:2021-10-16 16:08:46
【问题描述】:

我正在解决leetcode 1155,这是关于目标总和的骰子数量。我正在使用基于字典的记忆。这是确切的代码:

class Solution:
    def numRollsToTarget(self, dices: int, faces: int, target: int) -> int:
        
        dp = {}
        def ways(t, rd):
            if t == 0  and rd == 0: return 1
            if t <= 0 or rd <= 0: return 0
            if dp.get((t,rd)): return dp[(t,rd)]
            dp[(t,rd)] = sum(ways(t-i, rd-1) for i in range(1,faces+1))
            return dp[(t,rd)]
        
        return ways(target, dices)

但对于 15*15 左右的面和骰子组合,此解决方案总是超时

然后我发现了这个使用 functools.lru_cache 的解决方案,其余部分完全相同。此解决方案运行速度非常快。

class Solution:
    def numRollsToTarget(self, dices: int, faces: int, target: int) -> int:
        from functools import lru_cache
        @lru_cache(None)
        def ways(t, rd):
            if t == 0  and rd == 0: return 1
            if t <= 0 or rd <= 0: return 0
            return sum(ways(t-i, rd-1) for i in range(1,faces+1))
        
        return ways(target, dices)

之前,我比较过,发现在大多数情况下,lru_cache 的性能并没有比基于字典的缓存高出这么多。

有人能解释一下为什么这两种方法之间存在如此巨大的性能差异吗?

【问题讨论】:

  • 如果您将if dp.get((t,rd)): return dp[(t,rd)] 作为函数的第一行,那么您的代码将更类似于带有lru_cache 的版本。您不必检查 t == 0, .... 以获取 dp 中的值。而lru_cache 也不检查t == 0, ....
  • python 需要一些时间来访问字典中的元素,所以也许你应该将 sum 分配给局部变量 - result = sum(...) - 以及稍后的 dp[(t,rd)] = resultreturn result。这样您将只运行一次dp[(t,rd)]
  • 这些步骤可能只提供边际改进,但我尝试时仍然超时。
  • 您也可以先运行cProfile 以查看您的瓶颈并尝试对其进行优化...我将发布类似的内容以进行比较。
  • 一些cProfile结果分享哪个区域正在做heavy-lifting,可能会导致问题。然后又修改了memo的版本是profile来比较一下区别。

标签: python dynamic-programming memoization


【解决方案1】:

首先,使用cProfile 运行您的 OP 代码,这是报告:

  • with print(numRollsToTarget2(4, 6, 20))(OP 版本)

您可以立即发现 ways genexprsum 中有一些沉重的电话。那是概率。需要仔细检查并尝试改进/减少。下一个帖子是类似的memo 版本,但calls 要少得多。该版本已通过,没有超时。

35
         2864 function calls (366 primitive calls) in 0.018 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.018    0.018 <string>:1(<module>)
        1    0.000    0.000    0.001    0.001 dice_rolls.py:23(numRollsToTarget2)
   1075/1    0.001    0.000    0.001    0.001 dice_rolls.py:25(ways)
   1253/7    0.001    0.000    0.001    0.000 dice_rolls.py:30(<genexpr>)
        1    0.000    0.000    0.018    0.018 dice_rolls.py:36(main)
       21    0.000    0.000    0.000    0.000 rpc.py:153(debug)
        3    0.000    0.000    0.017    0.006 rpc.py:216(remotecall)
        3    0.000    0.000    0.000    0.000 rpc.py:226(asynccall)
        3    0.000    0.000    0.016    0.005 rpc.py:246(asyncreturn)
        3    0.000    0.000    0.000    0.000 rpc.py:252(decoderesponse)
        3    0.000    0.000    0.016    0.005 rpc.py:290(getresponse)
        3    0.000    0.000    0.000    0.000 rpc.py:298(_proxify)
        3    0.000    0.000    0.016    0.005 rpc.py:306(_getresponse)
        3    0.000    0.000    0.000    0.000 rpc.py:328(newseq)
        3    0.000    0.000    0.000    0.000 rpc.py:332(putmessage)
        2    0.000    0.000    0.001    0.000 rpc.py:559(__getattr__)
        3    0.000    0.000    0.000    0.000 rpc.py:57(dumps)
        1    0.000    0.000    0.001    0.001 rpc.py:577(__getmethods)
        2    0.000    0.000    0.000    0.000 rpc.py:601(__init__)
        2    0.000    0.000    0.016    0.008 rpc.py:606(__call__)
        4    0.000    0.000    0.000    0.000 run.py:412(encoding)
        4    0.000    0.000    0.000    0.000 run.py:416(errors)
        2    0.000    0.000    0.017    0.008 run.py:433(write)
        6    0.000    0.000    0.000    0.000 threading.py:1306(current_thread)
        3    0.000    0.000    0.000    0.000 threading.py:222(__init__)
        3    0.000    0.000    0.016    0.005 threading.py:270(wait)
        3    0.000    0.000    0.000    0.000 threading.py:81(RLock)
        3    0.000    0.000    0.000    0.000 {built-in method _struct.pack}
        3    0.000    0.000    0.000    0.000 {built-in method _thread.allocate_lock}
        6    0.000    0.000    0.000    0.000 {built-in method _thread.get_ident}
        1    0.000    0.000    0.018    0.018 {built-in method builtins.exec}
        6    0.000    0.000    0.000    0.000 {built-in method builtins.isinstance}
        9    0.000    0.000    0.000    0.000 {built-in method builtins.len}
        1    0.000    0.000    0.017    0.017 {built-in method builtins.print}
    179/1    0.000    0.000    0.001    0.001 {built-in method builtins.sum}
        3    0.000    0.000    0.000    0.000 {built-in method select.select}
        3    0.000    0.000    0.000    0.000 {method '_acquire_restore' of '_thread.RLock' objects}
        3    0.000    0.000    0.000    0.000 {method '_is_owned' of '_thread.RLock' objects}
        3    0.000    0.000    0.000    0.000 {method '_release_save' of '_thread.RLock' objects}
        3    0.000    0.000    0.000    0.000 {method 'acquire' of '_thread.RLock' objects}
        6    0.016    0.003    0.016    0.003 {method 'acquire' of '_thread.lock' objects}
        3    0.000    0.000    0.000    0.000 {method 'append' of 'collections.deque' objects}
        2    0.000    0.000    0.000    0.000 {method 'decode' of 'bytes' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
        3    0.000    0.000    0.000    0.000 {method 'dump' of '_pickle.Pickler' objects}
        2    0.000    0.000    0.000    0.000 {method 'encode' of 'str' objects}
      201    0.000    0.000    0.000    0.000 {method 'get' of 'dict' objects}
        3    0.000    0.000    0.000    0.000 {method 'getvalue' of '_io.BytesIO' objects}
        3    0.000    0.000    0.000    0.000 {method 'release' of '_thread.RLock' objects}
        3    0.000    0.000    0.000    0.000 {method 'send' of '_socket.socket' objects}

然后我尝试运行修改/简化版本,并比较结果。

35
         387 function calls (193 primitive calls) in 0.006 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.006    0.006 <string>:1(<module>)
        1    0.000    0.000    0.006    0.006 dice_rolls.py:36(main)
        1    0.000    0.000    0.000    0.000 dice_rolls.py:5(numRollsToTarget)
    195/1    0.000    0.000    0.000    0.000 dice_rolls.py:8(dp)
       21    0.000    0.000    0.000    0.000 rpc.py:153(debug)
        3    0.000    0.000    0.006    0.002 rpc.py:216(remotecall)
        3    0.000    0.000    0.000    0.000 rpc.py:226(asynccall)
        3    0.000    0.000    0.006    0.002 rpc.py:246(asyncreturn)
        3    0.000    0.000    0.000    0.000 rpc.py:252(decoderesponse)
        3    0.000    0.000    0.006    0.002 rpc.py:290(getresponse)
        3    0.000    0.000    0.000    0.000 rpc.py:298(_proxify)
        3    0.000    0.000    0.006    0.002 rpc.py:306(_getresponse)
        3    0.000    0.000    0.000    0.000 rpc.py:328(newseq)
        3    0.000    0.000    0.000    0.000 rpc.py:332(putmessage)
        2    0.000    0.000    0.001    0.000 rpc.py:559(__getattr__)
        3    0.000    0.000    0.000    0.000 rpc.py:57(dumps)
        1    0.000    0.000    0.001    0.001 rpc.py:577(__getmethods)
        2    0.000    0.000    0.000    0.000 rpc.py:601(__init__)
        2    0.000    0.000    0.005    0.003 rpc.py:606(__call__)
        4    0.000    0.000    0.000    0.000 run.py:412(encoding)
        4    0.000    0.000    0.000    0.000 run.py:416(errors)
        2    0.000    0.000    0.006    0.003 run.py:433(write)
        6    0.000    0.000    0.000    0.000 threading.py:1306(current_thread)
        3    0.000    0.000    0.000    0.000 threading.py:222(__init__)
        3    0.000    0.000    0.006    0.002 threading.py:270(wait)
        3    0.000    0.000    0.000    0.000 threading.py:81(RLock)
        3    0.000    0.000    0.000    0.000 {built-in method _struct.pack}
        3    0.000    0.000    0.000    0.000 {built-in method _thread.allocate_lock}
        6    0.000    0.000    0.000    0.000 {built-in method _thread.get_ident}
        1    0.000    0.000    0.006    0.006 {built-in method builtins.exec}
        6    0.000    0.000    0.000    0.000 {built-in method builtins.isinstance}
        9    0.000    0.000    0.000    0.000 {built-in method builtins.len}
       34    0.000    0.000    0.000    0.000 {built-in method builtins.max}
        1    0.000    0.000    0.006    0.006 {built-in method builtins.print}
        3    0.000    0.000    0.000    0.000 {built-in method select.select}
        3    0.000    0.000    0.000    0.000 {method '_acquire_restore' of '_thread.RLock' objects}
        3    0.000    0.000    0.000    0.000 {method '_is_owned' of '_thread.RLock' objects}
        3    0.000    0.000    0.000    0.000 {method '_release_save' of '_thread.RLock' objects}
        3    0.000    0.000    0.000    0.000 {method 'acquire' of '_thread.RLock' objects}
        6    0.006    0.001    0.006    0.001 {method 'acquire' of '_thread.lock' objects}
        3    0.000    0.000    0.000    0.000 {method 'append' of 'collections.deque' objects}
        2    0.000    0.000    0.000    0.000 {method 'decode' of 'bytes' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
        3    0.000    0.000    0.000    0.000 {method 'dump' of '_pickle.Pickler' objects}
        2    0.000    0.000    0.000    0.000 {method 'encode' of 'str' objects}
        2    0.000    0.000    0.000    0.000 {method 'get' of 'dict' objects}
        3    0.000    0.000    0.000    0.000 {method 'getvalue' of '_io.BytesIO' objects}
        3    0.000    0.000    0.000    0.000 {method 'release' of '_thread.RLock' objects}
        3    0.000    0.000    0.000    0.000 {method 'send' of '_socket.socket' objects}

分析代码在这里:

import cProfile
from typing import List

def numRollsToTarget(d, f, target):
    memo = {}

    def dp(d, target):
        if d == 0:
            return 0 if target > 0 else 1
        if (d, target) in memo:
            return memo[(d, target)]

        result = 0
        
        for k in range(max(0, target-f), target):
            result += dp(d-1, k)
        memo[(d, target)] = result
        return result 
    
    return dp(d, target) % (10**9 + 7)
    
def numRollsToTarget2(dices: int, faces: int, target: int) -> int:
    dp = {}
    def ways(t, rd):
        if t == 0  and rd == 0: return 1
        if t <= 0 or rd <= 0: return 0
        if dp.get((t,rd)): return dp[(t,rd)]
        
        dp[(t,rd)] = sum(ways(t-i, rd-1) for i in range(1,faces+1))
        return dp[(t,rd)]
        
    return ways(target, dices)

def numRollsToTarget3(dices: int, faces: int, target: int) -> int:
    from functools import lru_cache
    @lru_cache(None)
    def ways(t, rd):
        if t == 0  and rd == 0: return 1
        if t <= 0 or rd <= 0: return 0
        return sum(ways(t-i, rd-1) for i in range(1,faces+1))
        
    return ways(target, dices)
def main():
    print(numRollsToTarget(4, 6, 20))
    #print(numRollsToTarget2(4, 6, 20))
    #print(numRollsToTarget3(4, 6, 20))  # not faster than first



if __name__ == '__main__':
    cProfile.run('main()')

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2021-11-24
    • 2014-12-29
    • 1970-01-01
    • 2021-11-29
    • 2016-06-27
    • 1970-01-01
    • 2016-05-19
    相关资源
    最近更新 更多