【发布时间】:2018-05-26 01:23:54
【问题描述】:
我对使用 Numba 有点陌生,但我明白了它的要点。我想知道是否有任何更高级的技巧可以使四个嵌套的for 循环比我现在拥有的更快。特别是,我需要计算以下积分:
其中 B 是一个二维数组,S0 和 E 是某些参数。我的代码如下:
import numpy as np
from numba import njit, double
def calc_gb_gauss_2d(b,s0,e,dx):
n,m=b.shape
norm = 1.0/(2*np.pi*s0**2)
gb = np.zeros((n,m))
for i in range(n):
for j in range(m):
for ii in range(n):
for jj in range(m):
gb[i,j]+=np.exp(-(((i-ii)*dx)**2+((j-jj)*dx)**2)/(2.0*(s0*(1.0+e*b[i,j]))**2))
gb[i,j]*=norm
return gb
calc_gb_gauss_2d_nb = njit(double[:, :](double[:, :],double,double,double))(calc_gb_gauss_2d)
对于大小为256x256的输入数组,计算速度为:
In [4]: a=random.random((256,256))
In [5]: %timeit calc_gb_gauss_2d_nb(a,0.1,1.0,0.5)
The slowest run took 8.46 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 3: 1min 1s per loop
纯 Python 和 Numba 计算速度的比较给我这张图:
有什么方法可以优化我的代码以获得更好的性能?
【问题讨论】:
-
在
j循环中计算(2.0*(s0*(1.0+e*b[i,j]))**2),而不是最里面的循环。 -
另外,您的问题更适合code review,因为您的代码有效并且您正在寻找改进的方法。
-
非常感谢..所以我应该从这里删除这个问题并将其移至代码审查吗?
-
我想说,看看 CodeReview 上有多少关于 numba 的问题。我认为你在这里有更好的机会......
-
1) 不要使用显式类型声明。 (您不能明确声明输入数组在内存中是连续的,这是 SIMD 向量化所必需的)。看看numba.pydata.org/numba-doc/dev/user/performance-tips.html(fastmath=True 关键字和使用英特尔 SVML 可以在性能上产生相当大的差异)。同样使用最新的 Numba 版本,最近对并行函数的性能进行了一些优化。
标签: python loops iteration jit numba