【问题标题】:Why numba don't improve the speed of my knapsack function?为什么 numba 不提高我的背包功能的速度?
【发布时间】:2020-03-19 22:43:50
【问题描述】:

我尝试使用 numba 加速我的代码,但它似乎不起作用。该程序与@jit@njit 或纯 python 的时间相同(大约 10 秒)。但是我使用的是 numpy 而不是 list 或 dict。

这是我的代码:

import numpy as np
from numba import njit
import random
import line_profiler
import atexit
profile = line_profiler.LineProfiler()
atexit.register(profile.print_stats)

@njit
def knapSack(W, wt, val, n):
    K = np.full((n+1,W+1),0)
    N =  np.full((n+1,W+1,W+1),0)
    M =  np.full((n+1,W+1),0)

    for i in range(n+1):
        for w in range(W+1):
            if i==0 or w==0:
                K[i][w] = 0
            elif wt[i-1] <= w:
                if(val[i-1] + K[i-1][w-wt[i-1]] >  K[i-1][w]):
                    K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
                    c = N[i-1][w-wt[i-1]]
                    c[i] = i
                    N[i][w] = c
                else:
                    K[i][w] = K[i-1][w]
                    N[i][w] = N[i-1][w]
            else:
                K[i][w] = K[i-1][w]
    N[n][W][0] = K[n][W]
    return N[n][W]

@profile
def main():

    size = 1000
    val = [random.randint(1, size) for i in range(0, size)]
    wt = [random.randint(1, size) for i in range(0, size)]
    W = 1000
    n = len(val)
    a = knapSack(W, wt, val, n)
main()

【问题讨论】:

    标签: python python-3.x performance jit numba


    【解决方案1】:

    事实上,如果不改变方法本身,可能无法真正提高当前算法的性能。

    您的 N 数组包含大约 10 亿个对象 (1001 * 1001 * 1001)。您需要设置每个元素,因此您至少有十亿次操作。为了获得下限,我们假设设置一个数组元素需要一纳秒(实际上需要更多时间)。 10 亿次操作,每次 1 纳秒意味着需要 1 秒才能完成。正如我所说,每次操作可能需要超过 1 纳秒的时间,所以我们假设它需要 10 纳秒(可能有点高,但比 1 纳秒更现实),这意味着算法总共有 10 秒。

    因此,您输入的预期运行时间将在 1 秒到 10 秒之间。因此,如果您的 Python 版本需要 10 秒,那么它可能已经处于您选择的方法所能达到的极限,并且没有任何工具会(显着)改善该运行时间。


    使用np.zeros 代替np.full 可以让它更快一点:

    K = np.zeros((n+1, W+1), dtype=int)
    N = np.zeros((n+1, W+1, W+1), dtype=int)
    

    不要创建M,因为你不会使用它。


    由于您已经使用了 line-profiler,我决定看一看,结果如下:

    Line #      Hits         Time  Per Hit   % Time  Line Contents
    ==============================================================
         3                                           def knapSack(W, wt, val, n):
         4         1      19137.0  19137.0      0.0      K = np.full((n+1,W+1),0)
         5         1   19408592.0 19408592.0     28.1      N = np.full((n+1,W+1,W+1),0)
         6                                           
         7      1002       6412.0      6.4      0.0      for i in range(n+1):
         8   1003002    4186311.0      4.2      6.1          for w in range(W+1):
         9   1002001    4644031.0      4.6      6.7              if i==0 or w==0:
        10      2001      19663.0      9.8      0.0                  K[i][w] = 0
        11   1000000    5474080.0      5.5      7.9              elif wt[i-1] <= w:
        12    498365    9616406.0     19.3     13.9                  if(val[i-1] + K[i-1][w-wt[i-1]] >  K[i-1][w]):
        13     52596     902030.0     17.2      1.3                      K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
        14     52596     578740.0     11.0      0.8                      c = N[i-1][w-wt[i-1]]
        15     52596     295980.0      5.6      0.4                      c[i] = i
        16     52596    1239792.0     23.6      1.8                      N[i][w] = c
        17                                                           else:
        18    445769    5100917.0     11.4      7.4                      K[i][w] = K[i-1][w]
        19    445769   11677683.0     26.2     16.9                      N[i][w] = N[i-1][w]
        20                                                       else:
        21    501635    5801328.0     11.6      8.4                  K[i][w] = K[i-1][w]
        22         1         16.0     16.0      0.0      N[n][W][0] = K[n][W]
        23         1         14.0     14.0      0.0      return N[n][W]
    

    这表明瓶颈是np.fullN[i][w] = N[i-1][w]if(val[i-1] + K[i-1][w-wt[i-1]] &gt; K[i-1][w])。 Numba 不会改进前两个,因为它们已经使用了高度优化的 NumPy 代码,对于这些,numba 更有可能更慢。 Numba 可能可以改进if(val[i-1] + K[i-1][w-wt[i-1]] &gt; K[i-1][w]),但这可能不会引起注意。

    如果np.fullnp.zeros 替换,则配置文件会略有变化:

    Line #      Hits         Time  Per Hit   % Time  Line Contents
    ==============================================================
         3                                           def knapSack(W, wt, val, n):
         4         1        747.0    747.0      0.0      K = np.zeros((n+1, W+1),dtype=int)
         5         1     109592.0 109592.0      0.2      N = np.zeros((n+1, W+1, W+1),dtype=int)
         6                                           
         7      1002       4230.0      4.2      0.0      for i in range(n+1):
         8   1003002    4414071.0      4.4      7.0          for w in range(W+1):
         9   1002001    4836807.0      4.8      7.7              if i==0 or w==0:
        10      2001      22282.0     11.1      0.0                  K[i][w] = 0
        11   1000000    5646859.0      5.6      8.9              elif wt[i-1] <= w:
        12    521222   10389581.0     19.9     16.5                  if(val[i-1] + K[i-1][w-wt[i-1]] >  K[i-1][w]):
        13     47579     784563.0     16.5      1.2                      K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
        14     47579     509056.0     10.7      0.8                      c = N[i-1][w-wt[i-1]]
        15     47579     362796.0      7.6      0.6                      c[i] = i
        16     47579    1975916.0     41.5      3.1                      N[i][w] = c
        17                                                           else:
        18    473643    5579823.0     11.8      8.8                      K[i][w] = K[i-1][w]
        19    473643   22805846.0     48.1     36.1                      N[i][w] = N[i-1][w]
        20                                                       else:
        21    478778    5664271.0     11.8      9.0                  K[i][w] = K[i-1][w]
        22         1         16.0     16.0      0.0      N[n][W][0] = K[n][W]
        23         1         10.0     10.0      0.0      return N[n][W]
    

    但主要瓶颈仍然是N[i][w] = N[i-1][w],使用 numba 可能比使用纯 NumPy 慢。因此,使用 numba 对代码的其他一些部分进行的改进可能不会被注意到(再次)。


    对于第一个配置文件,我使用了这个版本的代码(第二个配置文件只是将 np.full 更改为 np.zeros):

    import numpy as np
    
    def knapSack(W, wt, val, n):
        K = np.full((n+1,W+1),0)
        N = np.full((n+1,W+1,W+1),0)
    
        for i in range(n+1):
            for w in range(W+1):
                if i==0 or w==0:
                    K[i][w] = 0
                elif wt[i-1] <= w:
                    if(val[i-1] + K[i-1][w-wt[i-1]] >  K[i-1][w]):
                        K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
                        c = N[i-1][w-wt[i-1]]
                        c[i] = i
                        N[i][w] = c
                    else:
                        K[i][w] = K[i-1][w]
                        N[i][w] = N[i-1][w]
                else:
                    K[i][w] = K[i-1][w]
        N[n][W][0] = K[n][W]
        return N[n][W]
    
    import random
    size = 1000
    val = [random.randint(1, size) for i in range(0, size)]
    wt = [random.randint(1, size) for i in range(0, size)]
    W = 1000
    n = len(val)
    
    %lprun -f knapSack knapSack(W, wt, val, n)
    

    【讨论】:

    • 您好,感谢您的回复。我试过了,是的,它更好。我也找到了一个很好的解决方案,我的 N 数组可以减少很多:
    • 感谢你和我的修改,现在函数比以前快100,比纯python快3-4
    • @blaudantoine 不客气。 N 的更改很有趣,我没有考虑过这一点,因为我认为您需要完整的 N,因为数组是由函数返回的。如果您觉得有帮助,请不要忘记upvote the answer(请参阅When should I vote up?)。
    【解决方案2】:

    这里是新功能:

     @njit
        def knapSack(W, wt, val, n):
    
            K = np.zeros((n + 1, W + 1),dtype=np.int32)
            # In fact we must only save the previous combinations and the current, 
            # not all :) So N is considerably reduce
            N = np.zeros((2, W + 1, W + 1),dtype=np.int32)
    
            for i in range(n + 1):
                for w in range(W + 1):
                    if i == 0 or w == 0:
                        K[i][w] = 0
                    elif wt[i - 1] <= w:
                        if val[i - 1] + K[i - 1][w - wt[i - 1]] > K[i - 1][w]:
                            K[i][w] = val[i - 1] + K[i - 1][w - wt[i - 1]]
                            N[i%2][w] = np.copy(N[(i - 1)%2][w - wt[i - 1]])
                            N[i%2][w][i] = i
                        else:
                            K[i][w] = K[i - 1][w]
                            N[i%2][w] = N[(i - 1)%2][w]
                    else:
                        K[i][w] = K[i - 1][w]
            N[(n)%2][W][0] = K[n][W]
            return N[(n)%2][W]
    

    非常感谢 MSeifert !!

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2022-01-18
      • 1970-01-01
      • 2015-07-16
      • 1970-01-01
      • 1970-01-01
      • 2019-01-13
      相关资源
      最近更新 更多