Numba 的即时编译器应该通过在首次执行期间将函数编译为本机代码来很好地处理您面临的索引开销:
In [1]: %cpaste
Pasting code; enter '--' alone on the line to stop or use Ctrl-D.
:import numpy as np
:
:sig = np.random.randn(44100)
:alpha = .9887
:beta = .999
:
:def nonvectorized(sig):
: out = np.zeros_like(sig)
:
: for n in range(1, len(sig)):
: if np.abs(sig[n]) >= out[n-1]:
: out[n] = alpha * out[n-1] + (1 - alpha) * np.abs( sig[n] )
: else:
: out[n] = beta * out[n-1]
: return out
:--
In [2]: nonvectorized(sig)
Out[2]:
array([ 0. , 0.01862503, 0.04124917, ..., 1.2979579 ,
1.304247 , 1.30294275])
In [3]: %timeit nonvectorized(sig)
10 loops, best of 3: 80.2 ms per loop
In [4]: from numba import jit
In [5]: vectorized = jit(nonvectorized)
In [6]: np.allclose(vectorized(sig), nonvectorized(sig))
Out[6]: True
In [7]: %timeit vectorized(sig)
1000 loops, best of 3: 249 µs per loop
编辑:按照评论中的建议,添加 jit 基准。 jit(nonvectorized) 正在创建一个轻量级包装器,因此是一种廉价的操作。
In [8]: %timeit jit(nonvectorized)
10000 loops, best of 3: 45.3 µs per loop
函数本身是在第一次执行期间编译的(因此即时)这需要一段时间,但可能不会那么多:
In [9]: %timeit jit(nonvectorized)(sig)
10 loops, best of 3: 169 ms per loop