In these cases, it's worth considering numba, which can give you the best of both worlds:
import numpy as np
from numba import jit

def vanilla_mult(R, S):
    m, n = R.shape[0], S.shape[1]
    result = np.empty((m, n), dtype=R.dtype)
    for i in range(m):
        for j in range(n):
            result[i, j] = np.dot(R[i, :], S[i, j, :])
    return result

def broadcast_mult(R, S):
    return np.sum(R[:, np.newaxis, :] * S, axis=2)

@jit(nopython=True)
def jit_mult(R, S):
    m, n = R.shape[0], S.shape[1]
    result = np.empty((m, n), dtype=R.dtype)
    for i in range(m):
        for j in range(n):
            result[i, j] = np.dot(R[i, :], S[i, j, :])
    return result
Note that vanilla_mult and jit_mult have exactly the same implementation, but the latter is just-in-time compiled. Let's test it:
In [1]: import test # the above is in test.py
In [2]: import numpy as np
In [3]: m, n, d = 100, 100, 100
In [4]: R = np.random.rand(m, d)
In [5]: S = np.random.rand(m, n, d)
OK...
In [6]: %timeit test.broadcast_mult(R, S)
100 loops, best of 3: 1.95 ms per loop
In [7]: %timeit test.vanilla_mult(R, S)
100 loops, best of 3: 11.7 ms per loop
Ouch, the plain loop is about 6 times slower than broadcasting. However...
In [8]: %timeit test.jit_mult(R, S)
The slowest run took 760.57 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 3: 870 µs per loop
Nice! Simply JIT-compiling the loop cuts the runtime to less than half that of broadcasting. How does this scale?
In [12]: m, n, d = 1000, 1000, 100
In [13]: R = np.random.rand(m, d)
In [14]: S = np.random.rand(m, n, d)
In [15]: %timeit test.vanilla_mult(R, S)
1 loop, best of 3: 1.22 s per loop
In [16]: %timeit test.broadcast_mult(R, S)
1 loop, best of 3: 666 ms per loop
In [17]: %timeit test.jit_mult(R, S)
The slowest run took 7.59 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 3: 83.6 ms per loop
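It scales very well. Broadcasting starts to get bogged down by having to create a large intermediate array: it takes about half the time of the vanilla approach, but about 8 times as long as the JIT approach!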
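To put a number on that intermediate array, here is a rough sketch of my own (the helper name broadcast_temp_nbytes is hypothetical, and I'm assuming float64 data, which is what np.random.rand returns). The product R[:, np.newaxis, :] * S has shape (m, n, d) before the sum collapses it, so its size is just the element count times the item size:

```python
import numpy as np

def broadcast_temp_nbytes(m, n, d, dtype=np.float64):
    """Bytes needed for the (m, n, d) temporary that
    R[:, np.newaxis, :] * S creates before np.sum reduces it.
    Hypothetical helper, not part of test.py."""
    return m * n * d * np.dtype(dtype).itemsize

small = broadcast_temp_nbytes(100, 100, 100)
large = broadcast_temp_nbytes(1000, 1000, 100)

print(f"small case: {small / 1e6:.0f} MB")  # 8 MB
print(f"large case: {large / 1e6:.0f} MB")  # 800 MB
```

So the larger test case has to allocate and fill roughly 800 MB just for the temporary, while the loop versions only ever hold the (m, n) result.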
Edited to add
Finally, let's compare the np.einsum approach:
In [19]: %timeit np.einsum('md,mnd->mn', R, S)
10 loops, best of 3: 59.5 ms per loop
And it's the clear winner on speed. I'm not familiar enough with it to comment on its space requirements, though.
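Since the timings only mean something if all the variants compute the same thing, here is a small sanity check, as a sketch on tiny random inputs. The names loop_mult and einsum_mult are mine for illustration; I've left the numba version out so that this runs with NumPy alone, but jit_mult has the same body as the loop version.

```python
import numpy as np

def loop_mult(R, S):
    # Same body as vanilla_mult / jit_mult: one dot product per (i, j).
    m, n = R.shape[0], S.shape[1]
    result = np.empty((m, n), dtype=R.dtype)
    for i in range(m):
        for j in range(n):
            result[i, j] = np.dot(R[i, :], S[i, j, :])
    return result

def broadcast_mult(R, S):
    # Builds the full (m, n, d) product, then sums over d.
    return np.sum(R[:, np.newaxis, :] * S, axis=2)

def einsum_mult(R, S):
    # 'md,mnd->mn': multiply along d and sum it out, keeping m and n.
    return np.einsum('md,mnd->mn', R, S)

rng = np.random.default_rng(0)
R = rng.random((4, 3))     # (m, d)
S = rng.random((4, 5, 3))  # (m, n, d)

expected = loop_mult(R, S)
assert np.allclose(broadcast_mult(R, S), expected)
assert np.allclose(einsum_mult(R, S), expected)
```

If either assertion fails, the subscripts or the broadcast axes don't match the loop, and the timings above wouldn't be comparing like with like.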