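【Posted】: 2014-02-23 10:41:15
【Question】:
I have been working on speeding up the resampling calculation for a particle filter. Since Python has many ways to speed things up, I thought I would try them all. Unfortunately, the Numba version is incredibly slow. Since Numba should give a speed-up, I assume this is an error on my part.
I tried 4 different versions:
- Numba
- Python
- Numpy
- Cython
The code for each is below: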
import numpy as np
import scipy as sp
import numba as nb
from cython_resample import cython_resample

@nb.autojit
def numba_resample(qs, xs, rands):
    n = qs.shape[0]
    lookup = np.cumsum(qs)
    results = np.empty(n)
    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results

def python_resample(qs, xs, rands):
    n = qs.shape[0]
    lookup = np.cumsum(qs)
    results = np.empty(n)
    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results

def numpy_resample(qs, xs, rands):
    results = np.empty_like(qs)
    lookup = sp.cumsum(qs)
    for j, key in enumerate(rands):
        i = sp.argmax(lookup>key)
        results[j] = xs[i]
    return results

#The following is the code for the cython module. It was compiled in a
#separate file, but is included here to aid in the question.
"""
import numpy as np
cimport numpy as np
cimport cython

DTYPE = np.float64
ctypedef np.float64_t DTYPE_t

@cython.boundscheck(False)
def cython_resample(np.ndarray[DTYPE_t, ndim=1] qs,
                    np.ndarray[DTYPE_t, ndim=1] xs,
                    np.ndarray[DTYPE_t, ndim=1] rands):
    if qs.shape[0] != xs.shape[0] or qs.shape[0] != rands.shape[0]:
        raise ValueError("Arrays must have same shape")
    assert qs.dtype == xs.dtype == rands.dtype == DTYPE
    cdef unsigned int n = qs.shape[0]
    cdef unsigned int i, j
    cdef np.ndarray[DTYPE_t, ndim=1] lookup = np.cumsum(qs)
    cdef np.ndarray[DTYPE_t, ndim=1] results = np.zeros(n, dtype=DTYPE)
    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results
"""

if __name__ == '__main__':
    n = 100
    xs = np.arange(n, dtype=np.float64)
    qs = np.array([1.0/n,]*n)
    rands = np.random.rand(n)
    print "Timing Numba Function:"
    %timeit numba_resample(qs, xs, rands)
    print "Timing Python Function:"
    %timeit python_resample(qs, xs, rands)
    print "Timing Numpy Function:"
    %timeit numpy_resample(qs, xs, rands)
    print "Timing Cython Function:"
    %timeit cython_resample(qs, xs, rands)
This produces the following output:
Timing Numba Function:
1 loops, best of 3: 8.23 ms per loop
Timing Python Function:
100 loops, best of 3: 2.48 ms per loop
Timing Numpy Function:
1000 loops, best of 3: 793 µs per loop
Timing Cython Function:
10000 loops, best of 3: 25 µs per loop
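Any idea why the Numba code is so slow? I assumed it would be at least comparable to Numpy.
Note: if anyone has ideas on how to speed up the Numpy or Cython code samples, that would be nice too :) My main question is about Numba, though.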
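The `%timeit` lines in the script are IPython magics, so the file will not run under a plain interpreter. As a rough sketch only, a comparable measurement could be made with the standard-library `timeit` module. The helper name `best_time` is hypothetical (it is not part of the question), the code is written for Python 3, and the warm-up call is an assumption added so that any one-time cost on the first call is kept out of the measurement.

```python
import timeit

import numpy as np


def python_resample(qs, xs, rands):
    # Same pure-Python loop as in the question: for each draw, take the
    # first particle whose cumulative weight exceeds the draw.
    n = qs.shape[0]
    lookup = np.cumsum(qs)
    results = np.empty(n)
    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results


def best_time(func, args, repeat=3, number=10):
    """Hypothetical helper: best per-call time in seconds over `repeat` runs."""
    func(*args)  # warm-up call, deliberately not timed
    timer = timeit.Timer(lambda: func(*args))
    return min(timer.repeat(repeat=repeat, number=number)) / number


if __name__ == "__main__":
    n = 100
    xs = np.arange(n, dtype=np.float64)
    qs = np.full(n, 1.0 / n)
    rands = np.random.rand(n)
    seconds = best_time(python_resample, (qs, xs, rands))
    print("python_resample: %.1f us per call" % (seconds * 1e6))
```

The same helper could be pointed at the other three functions by passing them in place of `python_resample`.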
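【Comments】:
- I think a better place for this would be codereview.stackexchange.com
- Try with a much larger list?
- @IanAuld: Perhaps, but since others have gotten substantial speed-ups from Numba, I figure I am using it wrong rather than this being merely a profiling issue. That seems to me to fit the intended use of Stack Overflow.
- @JoranBeasley: I tried 1000 and 10000 points. Numba took 773 ms for 1000 points, compared to 234 ms for pure Python. The 10000-point trial is still running...
- Note that argmax can take an axis argument, so you can broadcast rands and lookup against each other to make an n x n matrix for an N^2 scaling algorithm. Alternatively, you can use searchsorted, which will have (should have?) N log(N) scaling.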
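The two suggestions in the last comment could look roughly like the sketch below. The function names `broadcast_resample` and `searchsorted_resample` are hypothetical, and the final clipping step is an assumption added to guard against a draw that is not below any cumulative weight (for example through rounding in the cumulative sum). In that edge case the original loops leave the output slot unset, so behaviour differs there.

```python
import numpy as np


def broadcast_resample(qs, xs, rands):
    """O(n^2) time and memory: compare every draw with every cumulative weight."""
    lookup = np.cumsum(qs)
    # hits[j, i] is True where rands[j] < lookup[i]; the result has shape (n, n).
    hits = rands[:, np.newaxis] < lookup[np.newaxis, :]
    # argmax along axis 1 returns the index of the first True in each row
    # (or 0 if a row contains no True at all).
    return xs[np.argmax(hits, axis=1)]


def searchsorted_resample(qs, xs, rands):
    """O(n log n): binary-search each draw in the sorted cumulative weights."""
    lookup = np.cumsum(qs)
    # side='right' yields the first index i with rands[j] < lookup[i].
    idx = np.searchsorted(lookup, rands, side="right")
    # Assumption: clip so a draw at or above lookup[-1] maps to the last particle.
    idx = np.minimum(idx, len(xs) - 1)
    return xs[idx]
```

Both functions avoid the explicit Python-level loops; the broadcast version trades memory for simplicity, while the `searchsorted` version relies on `lookup` being non-decreasing, which holds when the weights in `qs` are non-negative.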
Tags: python performance numpy cython numba