【发布时间】:2021-02-20 18:32:08
【问题描述】:
我决定开始使用Numba 来加速我拥有的一些科学代码。为了让我的脚湿透,我在谷歌上搜索并找到了一些提供运行时的教程。我所指的教程示例来自here。他们计算一个 numpy 数组的标准差,然后使用 Numba 的 @njit 装饰器再次计算,然后比较运行时。这是我在我的计算机上重新创建此代码的尝试:
import math
import numpy as np
import numba
from timeit import timeit
def std(xs):
# compute the mean
mean = 0
for x in xs:
mean += x
mean /= len(xs)
# compute the variance
ms = 0
for x in xs:
ms += (x - mean) ** 2
variance = ms / len(xs)
std = math.sqrt(variance)
return std
c_std = numba.njit(std)
a = np.random.normal(0, 1, 10000)
现在这是我在时间方面得到的结果:
print(timeit('std(a)', globals=globals(), number=1) * 1000, 'ms')
在 9 毫秒内运行。但是,当我合并 Numba 时:
print(timeit('c_std(a)', globals=globals(), number=1) * 1000, 'ms')
运行时间为 375 毫秒。
这是大约 40 倍的减速。在本教程中,使用相同的代码,他们的 Numba-jitted 代码运行时间为 31.6 毫秒,而没有 Numba 的时间为 4600 毫秒,加速约 150 倍。我可以看到我们的代码之间的唯一区别是它们的正态分布样本来自 10,000,000 个点,而我的只有 10,000 个,但这是一个必要的调整,因为运行更多点需要很长时间。
我在 Python 3.8 上使用 Numpy 1.19.2 和 Numba 0.51.2。我在 MacOs 10.14 上运行 conda 4.9.2。
【问题讨论】:
-
你正在计时编译!
-
但这不就是他们在教程中所做的吗..?
-
我认为编译发生在这里:
c_std = numba.njit(std) -
从您正在阅读的教程中:“另外,请记住,第一次调用函数时,numba 需要编译函数,这需要一些时间。”他们对函数在调用一次以编译它进行计时。
-
我明白了。但是我多次运行我的 .py 文件。它不记得 *.pyc 文件中的编译吗?如果您有任何建议,请告诉我,我会看看我是否可以将编译部分提取出来。