【问题标题】:Why numba is slower than pure python in my code?为什么在我的代码中 numba 比纯 python 慢?
【发布时间】:2021-03-10 12:35:45
【问题描述】:

我对 python 有点陌生,我在玩 numba 并编写了一个在 numba 中运行速度比纯 python 慢的代码。在少数情况下,纯 python 比 numba 快大约 x4 倍,在大量情况下,它们的运行几乎相同。是什么让我的代码在 numba 中运行缓慢?

from numba import njit
@njit
def forr (q):
    p=0
    k=q
    n=0
    while k!=0:
            n += 1
            k=k//10
    
    h=(abs(q-n*9)+q-n*9)//2 
    for j in range(q,h,-1):
        
        s=0
        k=j
        while k!=0:
            s += k%10
            k=k//10
        
        if s+j==q:
            p=1
            print('Yes')
            break
    if p==0:
        print('No')

【问题讨论】:

  • JIT 需要时间。所以,如果你只对少数几个值执行函数,你就不会弥补 JIT 时间。
  • 我看到一个循环中的打印语句
  • 您可以尝试使用@njit(parallel = True) 看看是否有任何改进。此外,正如 Tarik 评论的那样,我认为这些打印语句不适合 Numba。
  • @Anthraxff 正如我所测量的in my answer,如果测量正确,您的numba 代码真的快29x 倍!请阅读my answer
  • 你如何测量时间,输入是什么以及它如何依赖于函数的代码,即如果你让它不那么复杂,问题仍然存在吗?

标签: python performance numba processing-efficiency


【解决方案1】:

我认为您的 Numba 代码运行速度较慢的原因是因为接下来的事情:

  1. 您可能会在第一次 Numba JIT 编译代码时测量函数第一次运行的时间,这可能需要几秒钟。要获得正确的时间测量,您需要先单独调用 numba 函数以便对其进行 JIT 预编译。
  2. 您可能没有提供足够大的输入(输入数字),因此您的函数需要很少的时间,并且 numba 函数有一些开销来启动。如果可能的话,您应该在您的代码中将耗时较长的算法放入 Numba 函数中,至少需要几十毫秒才能运行。
  3. 您可能只测量几次运行,您必须在一个循环中测量数百次函数运行才能获得更准确的结果。
  4. 您没有将cache = True 选项放入@njit 装饰器中,此选项将有助于在每个脚本运行时获取预编译代码,而不是从头开始编译。
  5. 打印函数在函数内部调用本身花费很少时间的函数可能会占用相当多的时间,因为控制台操作很慢。最好从函数返回结果并在 Numba 函数之外打印。

考虑到上面所有的事情,我实现了下一个代码来测量你的 Numba 代码,我只是添加了cache = True 选项并注释掉了print() 测量时间的调用(测量时不要用数百个单词破坏控制台)。

下一个代码显示 Numba 变体在我的笔记本电脑上快了29x 倍。下一个代码还需要通过命令pip install numba timerit安装一次pip模块。

Try it online!

import timerit, numba
timerit.Timerit._default_asciimode = True

def forr(q):
    p=0
    k=q
    n=0
    while k!=0:
            n += 1
            k=k//10
    
    h=(abs(q-n*9)+q-n*9)//2 
    for j in range(q,h,-1):
        
        s=0
        k=j
        while k!=0:
            s += k%10
            k=k//10
        
        if s+j==q:
            p=1
            #print('Yes')
            break
    if p==0:
        #print('No')
        pass
        
nforr = numba.njit(cache = True)(forr)
nforr(2) # Heat-up, precompile numba

tb = None
for f in [forr, nforr]:
    tim = timerit.Timerit(num = 99, verbose = 1)
    for t in tim:
        f(1 << 60)
    if tb is None:
        tb = tim.mean()
    else:
        print(f'speedup {round(tb / tim.mean(), 1)}x')

输出:

Timed best=1.029 ms, mean=1.040 +- 0.0 ms
Timed best=35.300 us, mean=35.673 +- 0.3 us
speedup 29.2x

【讨论】:

  • 如回答打印需要一些时间。最好的方法是从函数返回值并一次打印它们。
猜你喜欢
  • 2014-02-23
  • 2017-12-22
  • 1970-01-01
  • 2018-11-12
  • 1970-01-01
  • 2021-09-23
  • 1970-01-01
  • 1970-01-01
  • 2011-03-14
相关资源
最近更新 更多