You can do better than the existing two nested loops by using a single loop:
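import numpy as np

m = A.shape[0]
n = B.shape[2]
r = A.shape[2]
# Output has shape (m, r, n), matching the 'ikl' output of the einsum version
out1 = np.empty((m,r,n), dtype=np.result_type(A.dtype, B.dtype))
# Loop only over the shared axis; each step is one 2D matrix product
for i in range(r):
    out1[:,i,:] = A[:, :, i] @ B[:, i,:]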
Alternatively, with np.matmul / the @ operator:
out = (A.transpose(2,0,1) @ B.transpose(1,0,2)).swapaxes(0,1)
Both of these seem to scale much better than the einsum version.
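As a quick sanity check, the sketch below compares both alternatives against the einsum call used in the timings. The sizes and the random generator are arbitrary choices made for this check, not taken from the benchmarks.

```python
import numpy as np

# Small, arbitrary sizes chosen only for this check (an assumption,
# not the benchmark sizes).
m, n, r = 6, 7, 5
rng = np.random.default_rng(0)
A = rng.standard_normal((m, m, r))
B = rng.standard_normal((m, r, n))

# Reference: the einsum formulation, out[i,k,l] = sum_j A[i,j,k] * B[j,k,l].
ref = np.einsum('ijk,jkl->ikl', A, B)

# Single loop over the shared axis.
out1 = np.empty((m, r, n), dtype=np.result_type(A.dtype, B.dtype))
for i in range(r):
    out1[:, i, :] = A[:, :, i] @ B[:, i, :]

# Batched matmul: move the shared axis to the front, multiply, swap it back.
out2 = (A.transpose(2, 0, 1) @ B.transpose(1, 0, 2)).swapaxes(0, 1)

print(out1.shape, out2.shape)
print(np.allclose(out1, ref), np.allclose(out2, ref))
```

The transposes only create views with the shared axis first, so `@` treats that axis as a batch dimension and multiplies the trailing `(m, m)` and `(m, n)` matrices pairwise.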
Timings
Case #1: sizes scaled to 1/4
In [44]: m = 500
...: n = 2000
...: r = 80
...: m,n,r = m//4, n//4, r//4
...:
...: A = norm.rvs(size = (m, m, r), random_state = 0)
...: B = norm.rvs(size = (m, r, n), random_state = 0)
In [45]: %%timeit
...: out1 = np.empty((m,r,n), dtype=np.result_type(A.dtype, B.dtype))
...: for i in range(r):
...:     out1[:,i,:] = A[:, :, i] @ B[:, i,:]
175 ms ± 6.54 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [46]: %timeit (A.transpose(2,0,1) @ B.transpose(1,0,2)).swapaxes(0,1)
165 ms ± 1.11 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [47]: %timeit np.einsum('ijk,jkl->ikl', A, B, optimize=True)
483 ms ± 13.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
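To reproduce this kind of comparison outside IPython, something along the following lines might work. It is only a sketch: the function names, the `time_all` helper, the tiny sizes, and the use of `np.random.default_rng` in place of `norm.rvs` are all assumptions made for illustration, and the absolute numbers will differ from those shown here.

```python
import timeit
import numpy as np

def loop_version(A, B):
    # Single loop over the shared axis.
    m, r, n = A.shape[0], A.shape[2], B.shape[2]
    out = np.empty((m, r, n), dtype=np.result_type(A.dtype, B.dtype))
    for i in range(r):
        out[:, i, :] = A[:, :, i] @ B[:, i, :]
    return out

def matmul_version(A, B):
    # Batched matmul on transposed views.
    return (A.transpose(2, 0, 1) @ B.transpose(1, 0, 2)).swapaxes(0, 1)

def einsum_version(A, B):
    return np.einsum('ijk,jkl->ikl', A, B, optimize=True)

def time_all(m, n, r, repeat=3):
    """Hypothetical helper: best-of-`repeat` wall time in seconds per version."""
    rng = np.random.default_rng(0)
    A = rng.standard_normal((m, m, r))
    B = rng.standard_normal((m, r, n))
    results = {}
    for f in (loop_version, matmul_version, einsum_version):
        results[f.__name__] = min(
            timeit.repeat(lambda: f(A, B), number=1, repeat=repeat)
        )
    return results

# Deliberately tiny sizes so the sketch finishes quickly.
timings = time_all(m=50, n=200, r=8)
for name, seconds in timings.items():
    print(f"{name}: {seconds * 1e3:.2f} ms")
```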
As we scale up, memory congestion starts to favor the single-loop version.
Case #2: sizes scaled to 1/2
In [48]: m = 500
...: n = 2000
...: r = 80
...: m,n,r = m//2, n//2, r//2
...:
...: A = norm.rvs(size = (m, m, r), random_state = 0)
...: B = norm.rvs(size = (m, r, n), random_state = 0)
In [49]: %%timeit
...: out1 = np.empty((m,r,n), dtype=np.result_type(A.dtype, B.dtype))
...: for i in range(r):
...:     out1[:,i,:] = A[:, :, i] @ B[:, i,:]
2.9 s ± 58.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [50]: %timeit (A.transpose(2,0,1) @ B.transpose(1,0,2)).swapaxes(0,1)
3.02 s ± 94.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Case #3: sizes scaled to 67%
In [59]: m = 500
...: n = 2000
...: r = 80
...: m,n,r = int(m/1.5), int(n/1.5), int(r/1.5)
In [60]: A = norm.rvs(size = (m, m, r), random_state = 0)
...: B = norm.rvs(size = (m, r, n), random_state = 0)
In [61]: %%timeit
...: out1 = np.empty((m,r,n), dtype=np.result_type(A.dtype, B.dtype))
...: for i in range(r):
...:     out1[:,i,:] = A[:, :, i] @ B[:, i,:]
25.8 s ± 4.9 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [62]: %timeit (A.transpose(2,0,1) @ B.transpose(1,0,2)).swapaxes(0,1)
29.2 s ± 2.41 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
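To put the three cases in perspective, the array sizes and the amount of arithmetic can be worked out from the shapes alone. The sketch below assumes float64 data (8 bytes per element, which is what `norm.rvs` returns); `case_sizes` is a hypothetical helper introduced only for this illustration.

```python
def case_sizes(m, n, r, itemsize=8):
    """Return (A, B, out) sizes in MB and the multiply-add count."""
    a_mb = m * m * r * itemsize / 1e6      # A has shape (m, m, r)
    b_mb = m * r * n * itemsize / 1e6      # B has shape (m, r, n)
    out_mb = m * r * n * itemsize / 1e6    # out has shape (m, r, n)
    macs = m * m * n * r                   # r products of (m, m) @ (m, n)
    return a_mb, b_mb, out_mb, macs

# Same scaling expressions as in the three timing cases.
cases = {
    "case 1 (1/4)": (500 // 4, 2000 // 4, 80 // 4),
    "case 2 (1/2)": (500 // 2, 2000 // 2, 80 // 2),
    "case 3 (67%)": (int(500 / 1.5), int(2000 / 1.5), int(80 / 1.5)),
}
for name, (m, n, r) in cases.items():
    a_mb, b_mb, out_mb, macs = case_sizes(m, n, r)
    print(f"{name}: A={a_mb:.1f} MB, B={b_mb:.1f} MB, "
          f"out={out_mb:.1f} MB, {macs / 1e9:.2f} G multiply-adds")
```

From case 1 to case 2 the multiply-add count grows 16x, roughly in line with the jump from 175 ms to 2.9 s. From case 2 to case 3 it grows only about 3.1x, while the measured time grows about 9x, which fits the memory-congestion remark above rather than arithmetic being the limit.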
Numba spin-off
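import numpy as np
from numba import njit, prange

@njit(parallel=True)
def func1(A, B):
    m = A.shape[0]
    n = B.shape[2]
    r = A.shape[2]
    # np.empty defaults to float64 here
    out = np.empty((m,r,n))
    # prange lets Numba run the iterations over the shared axis in parallel
    for i in prange(r):
        out[:,i,:] = A[:, :, i] @ B[:, i,:]
    return out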
Timings for case #3:
In [80]: m = 500
...: n = 2000
...: r = 80
...: m,n,r = int(m/1.5), int(n/1.5), int(r/1.5)
In [81]: A = norm.rvs(size = (m, m, r), random_state = 0)
...: B = norm.rvs(size = (m, r, n), random_state = 0)
In [82]: %timeit func1(A, B)
653 ms ± 10.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)