In [749]: x = np.random.choice(100, size=(23, 10, 3))
...: a = x[:, :, np.newaxis, :]
...: b = x[:, np.newaxis, :, :]
...: y = np.sum(a * b, axis=3)
In [750]: a.shape
Out[750]: (23, 10, 1, 3) # a view, no extra memory
In [751]: b.shape
Out[751]: (23, 1, 10, 3)
In [752]: y.shape
Out[752]: (23, 10, 10)
In [753]: (a*b).shape
Out[753]: (23, 10, 10, 3) # 3x larger than y
我不知道你怎么数了 9 次。
这也可以用einsum表示:
In [758]: np.einsum('ijl,ikl->ijk', x, x).shape
Out[758]: (23, 10, 10)
In [759]: np.allclose(np.einsum('ijl,ikl->ijk', x, x),y)
Out[759]: True
我不确定它的内存使用情况如何。在原始形式中,它迭代了一个 'ijkl' 空间。
快一点:
In [760]: timeit np.einsum('ijl,ikl->ijk', x, x).shape
74.1 µs ± 256 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
In [761]: timeit y = np.sum(a * b, axis=3)
90.9 µs ± 86.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
这需要更多的工作,但我找到了一种更快的方法,使用matmul:
In [771]: (a@b.transpose(0,1,3,2)).shape
Out[771]: (23, 10, 1, 10)
In [772]: np.allclose((a@b.transpose(0,1,3,2)).squeeze(),y)
Out[772]: True
In [773]: timeit (a@b.transpose(0,1,3,2)).shape
20 µs ± 28 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
这会将更多工作转移到快速编译的库中。我不能说内存使用。
在重复的like中找到的更简单的解决方案要快一点:
In [777]: timeit (x@x.transpose(0,2,1)).shape
18.4 µs ± 181 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)