【发布时间】:2019-04-09 18:43:36
【问题描述】:
在遍历 NumPy 数组时,Numba 似乎比 Cython 快得多。
我可能缺少哪些 Cython 优化?
这是一个简单的例子:
纯 Python 代码:
import numpy as np
def f(arr):
res=np.zeros(len(arr))
for i in range(len(arr)):
res[i]=(arr[i])**2
return res
arr=np.random.rand(10000)
%timeit f(arr)
输出:每个循环 4.81 ms ± 72.2 µs(平均值 ± 标准偏差,7 次运行,每次 100 个循环)
Cython 代码(在 Jupyter 中):
%load_ext cython
%%cython
import numpy as np
cimport numpy as np
cimport cython
from libc.math cimport pow
#@cython.boundscheck(False)
#@cython.wraparound(False)
cpdef f(double[:] arr):
cdef np.ndarray[dtype=np.double_t, ndim=1] res
res=np.zeros(len(arr),dtype=np.double)
cdef double[:] res_view=res
cdef int i
for i in range(len(arr)):
res_view[i]=pow(arr[i],2)
return res
arr=np.random.rand(10000)
%timeit f(arr)
输出:每个循环 445 µs ± 5.49 µs(平均值 ± 标准偏差,7 次运行,每次 1000 个循环)
Numba 代码:
import numpy as np
import numba as nb
@nb.jit(nb.float64[:](nb.float64[:]))
def f(arr):
res=np.zeros(len(arr))
for i in range(len(arr)):
res[i]=(arr[i])**2
return res
arr=np.random.rand(10000)
%timeit f(arr)
输出:每个循环 9.59 µs ± 98.8 ns(平均值 ± 标准偏差,7 次运行,每次 100000 次循环)
在此示例中,Numba 的速度几乎是 Cython 的 50 倍。
作为一个 Cython 初学者,我想我错过了一些东西。
当然,在这种简单的情况下,使用 NumPy square 向量化函数会更合适:
%timeit np.square(arr)
输出:每个循环 5.75 µs ± 78.9 ns(平均值 ± 标准偏差,7 次运行,每次 100000 次循环)
【问题讨论】:
-
你为什么不在 cython 代码中也做 arr[i]**2 ?我认为一个可能的原因是
pow(arr[i],2)会将2视为浮点数并使计算更加复杂 -
谢谢,但我也尝试过使用 arr[i]**2 而不是 pow(arr[i],2) ,两种解决方案的性能几乎相同。一般来说,即使在没有数学转换的情况下对 numpy 数组进行简单迭代,numba 编译函数的运行速度也比 cython 快。