【发布时间】:2020-08-06 22:59:55
【问题描述】:
我正在尝试将三个数组 (A x B x A) 与维度 (19000, 3) x (19000, 3, 3) x (19000, 3) 相乘,这样最后我得到一个大小为 (19000) 的一维数组,所以我只想沿最后一维/二维相乘。
我已经让它与 np.einsum() 一起工作,但我想知道是否有任何方法可以加快速度,因为这是我整个代码的瓶颈。
np.einsum('...i,...ij,...j', A, B, A)
我已经尝试过两次单独的 np.einsum() 调用,但这给了我相同的性能:
np.einsum('...i, ...i', np.einsum('...i,...ij', A, B), A)
我也已经尝试过 @ 运算符并添加了一些额外的轴,但这也没有让它更快:
(A[:, None]@B@A[...,None]).squeeze()
我试图让它与 np.inner()、np.dot()、np.tensordot() 和 np.vdot() 一起工作,但这些从来没有给我相同的结果,所以我不能比较一下。
还有其他想法吗?有什么方法可以让我的表现更好吗?
我已经快速了解了 Numba,但由于 Numba 不支持 np.einsum() 和许多其他 NumPy 函数,我将不得不重写很多代码。
【问题讨论】:
-
您可以尝试在
np.einsum中设置optimize=True吗? -
我已经试过了。但不幸的是,性能根本没有变化。
-
如果这个函数是你的主要瓶颈,只需在 Numba 中编写 einsum 表达式就足够了。 (已经尝试过 -> 并行化比使用 optimize=True 的 einsum 快大约 10 倍)
-
@max9111 我是 Numba 的绝对初学者:我该怎么做?将重写的 np.einsum 表达式放入一个额外的函数中,然后为其添加一个 Numba 装饰器?
-
你试过 opt_einsum 吗?