【问题标题】:Any chance of making this faster? (numpy.einsum)有没有机会让它更快? (numpy.einsum)
【发布时间】:2020-08-06 22:59:55
【问题描述】:

我正在尝试将三个数组 (A x B x A) 与维度 (19000, 3) x (19000, 3, 3) x (19000, 3) 相乘,这样最后我得到一个大小为 (19000) 的一维数组,所以我只想沿最后一维/二维相乘。

我已经让它与 np.einsum() 一起工作,但我想知道是否有任何方法可以加快速度,因为这是我整个代码的瓶颈。

np.einsum('...i,...ij,...j', A, B, A)

我已经尝试过两次单独的 np.einsum() 调用,但这给了我相同的性能:

np.einsum('...i, ...i', np.einsum('...i,...ij', A, B), A)

我也已经尝试过 @ 运算符并添加了一些额外的轴,但这也没有让它更快:

(A[:, None]@B@A[...,None]).squeeze()

我试图让它与 np.inner()、np.dot()、np.tensordot() 和 np.vdot() 一起工作,但这些从来没有给我相同的结果,所以我不能比较一下。

还有其他想法吗?有什么方法可以让我的表现更好吗?

我已经快速了解了 Numba,但由于 Numba 不支持 np.einsum() 和许多其他 NumPy 函数,我将不得不重写很多代码。

【问题讨论】:

  • 您可以尝试在np.einsum 中设置optimize=True 吗?
  • 我已经试过了。但不幸的是,性能根本没有变化。
  • 如果这个函数是你的主要瓶颈,只需在 Numba 中编写 einsum 表达式就足够了。 (已经尝试过 -> 并行化比使用 optimize=True 的 einsum 快大约 10 倍)
  • @max9111 我是 Numba 的绝对初学者:我该怎么做?将重写的 np.einsum 表达式放入一个额外的函数中,然后为其添加一个 Numba 装饰器?
  • 你试过 opt_einsum 吗?

标签: python numpy


【解决方案1】:

你可以使用 Numba

一开始,看看 np.einsum 做了什么总是一个好主意。对于optimize==optimal,找到一种减少FLOPs 的收缩方式通常非常好。在这种情况下,实际上只有很小的优化可能,并且中间数组相对较大(我会坚持使用幼稚的版本)。还应该提到的是,尺寸非常小(固定?)的收缩是一种非常特殊的情况。这也是为什么在这里很容易超越 np.einsum 的原因(展开等...,如果编译器知道循环仅包含 3 个元素,它就会这样做)

import numpy as np

A=np.random.rand(19000, 3)
B=np.random.rand(19000, 3, 3)

print(np.einsum_path('...i,...ij,...j', A, B, A,optimize="optimal")[1])

"""
  Complete contraction:  si,sij,sj->s
         Naive scaling:  3
     Optimized scaling:  3
      Naive FLOP count:  5.130e+05
  Optimized FLOP count:  4.560e+05
   Theoretical speedup:  1.125
  Largest intermediate:  5.700e+04 elements
--------------------------------------------------------------------------
scaling                  current                                remaining
--------------------------------------------------------------------------
   3                  sij,si->js                                 sj,js->s
   2                    js,sj->s                                     s->s

"""

Numba 实现

import numba as nb

#si,sij,sj->s
@nb.njit(fastmath=True,parallel=True,cache=True)
def nb_einsum(A,B):
    #check the input's at the beginning
    #I assume that the asserted shapes are always constant
    #This makes it easier for the compiler to optimize 
    assert A.shape[1]==3
    assert B.shape[1]==3
    assert B.shape[2]==3

    #allocate output
    res=np.empty(A.shape[0],dtype=A.dtype)

    for s in nb.prange(A.shape[0]):
        #Using a syntax like that is also important for performance
        acc=0
        for i in range(3):
            for j in range(3):
                acc+=A[s,i]*B[s,i,j]*A[s,j]
        res[s]=acc
    return res

时间

#warmup the first call is always slower 
#(due to compilation or loading the cached function)
res=nb_einsum(A,B)
%timeit nb_einsum(A,B)
#43.2 µs ± 1.22 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit np.einsum('...i,...ij,...j', A, B, A,optimize=True)
#450 µs ± 8.28 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit np.einsum('...i,...ij,...j', A, B, A)
#977 µs ± 4.14 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
np.allclose(np.einsum('...i,...ij,...j', A, B, A,optimize=True),nb_einsum(A,B))
#True

【讨论】:

  • 非常感谢!我一定会研究您的 Numba 实现并稍后尝试!
  • 好吧,看来我完全错了。真正的瓶颈不是 np.einsum 而是我的数组 B,它实际上是另一个数组的倒数,我们称之为 C。所以实际情况是:np.einsum('...i,...ij, ...j', A, inv(C), A) 逆计算是这里真正的瓶颈!在谷歌搜索了很多之后我发现,我什至不需要逆,而是可以只使用 np.solve(C, A)。所以最后我的实现是: np.einsum('...i,...i', A, np.linalg.solve(C, A)) 这让我的速度提高了大约 30 倍!
  • 我仍然会将您的答案标记为解决方案。非常感谢! :-)
猜你喜欢
  • 2018-01-05
  • 2017-03-26
  • 2018-05-05
  • 2019-08-06
  • 1970-01-01
  • 2022-06-14
  • 1970-01-01
  • 2020-03-08
  • 1970-01-01
相关资源
最近更新 更多