【问题标题】:Is there a numpy/scipy dot product, calculating only the diagonal entries of the result?是否有一个 numpy/scipy 点积,只计算结果的对角线条目?
【发布时间】:2013-01-23 09:00:51
【问题描述】:

想象有 2 个 numpy 数组:

> A, A.shape = (n,p)
> B, B.shape = (p,p)

通常 p 是一个较小的数字 (p

我正在做以下事情:

result = np.diag(A.dot(B).dot(A.T))

如您所见,我只保留了 n 个对角线条目,但是有一个计算出的中间 (n x n) 数组,仅保留对角线条目。

我希望有一个像 diag_dot() 这样的函数,它只计算结果的对角线条目,而不分配完整的内存。

结果是:

> result = diag_dot(A.dot(B), A.T)

是否有这样的预制功能,是否可以在无需分配中间 (n x n) 数组的情况下高效完成?

【问题讨论】:

    标签: python numpy scipy product


    【解决方案1】:

    我想我自己得到了它,但仍然会分享解决方案:

    因为只得到矩阵乘法的对角线

    > Z = N.diag(X.dot(Y))
    

    相当于X的行和Y的列的标量积的个体总和,前面的语句相当于:

    > Z = (X * Y.T).sum(-1)
    

    对于原始变量,这意味着:

    > result = (A.dot(B) * A).sum(-1)
    

    如果我错了,请纠正我,但应该是这样......

    【讨论】:

    • +1 智能代数总是比复杂的算法更好。
    • 如果有人不熟悉 numpy,这里的重点是 X.dot(Y) 运算符和 * 运算符之间的区别。 X.dot(Y) 表示线性代数中的常规矩阵乘积,而 X * Y 返回 X 和 Y 的条目之间的逐点乘积,因此 X 和 Y 需要具有相同的形状。
    • 虽然我喜欢你的回答,但我认为它会计算所有元素。这个问题的重点是让编译器明白我们只对对角线元素感兴趣:)
    【解决方案2】:

    您几乎可以通过numpy.einsum 获得您梦寐以求的任何东西。在你开始掌握它之前,它基本上看起来像是黑色巫术......

    >>> a = np.arange(15).reshape(5, 3)
    >>> b = np.arange(9).reshape(3, 3)
    
    >>> np.diag(np.dot(np.dot(a, b), a.T))
    array([  60,  672, 1932, 3840, 6396])
    >>> np.einsum('ij,ji->i', np.dot(a, b), a.T)
    array([  60,  672, 1932, 3840, 6396])
    >>> np.einsum('ij,ij->i', np.dot(a, b), a)
    array([  60,  672, 1932, 3840, 6396])
    

    编辑你实际上可以一次完成整个事情,这太荒谬了......

    >>> np.einsum('ij,jk,ki->i', a, b, a.T)
    array([  60,  672, 1932, 3840, 6396])
    >>> np.einsum('ij,jk,ik->i', a, b, a)
    array([  60,  672, 1932, 3840, 6396])
    

    编辑虽然你不想让它自己计算太多...添加了 OP 对其自己问题的答案以进行比较。

    n, p = 10000, 200
    a = np.random.rand(n, p)
    b = np.random.rand(p, p)
    
    In [2]: %timeit np.einsum('ij,jk,ki->i', a, b, a.T)
    1 loops, best of 3: 1.3 s per loop
    
    In [3]: %timeit np.einsum('ij,ij->i', np.dot(a, b), a)
    10 loops, best of 3: 105 ms per loop
    
    In [4]: %timeit np.diag(np.dot(np.dot(a, b), a.T))
    1 loops, best of 3: 5.73 s per loop
    
    In [5]: %timeit (a.dot(b) * a).sum(-1)
    10 loops, best of 3: 115 ms per loop
    

    【讨论】:

    • 我不知道这个功能 - 但现在肯定会这样做。谢谢分享!!!
    • 我相信'In [3]' 依赖于'dot' 是高度优化的 c 代码这一事实(废话?)但这确实构建了一个潜在的大型中间数组。
    • 我不确定 np.einsum('ij,ij->i', np.dot(a, b), a) 给出了什么,但它肯定与 Z = 的结果不同N.diag(X.dot(Y)) 提供点积的对角线元素(在本例中为 [15, 54, 111])。
    • 自 2013 年以来发生了一些变化。使用optimize=True 选项可以提高einsum 的性能。有一个新的面向批处理的matmul 可以解决这个问题:(A[:,None,:]@B@A[:,:,None])[:,0,0]。但是(A.dot(B) * A).sum(-1) 仍然是一个好方法。
    【解决方案3】:

    一个避免构建大型中间数组的行人答案是:

    result=np.empty([n,], dtype=A.dtype )
    for i in xrange(n):
        result[i]=A[i,:].dot(B).dot(A[i,:])
    

    【讨论】:

    • [n.] 不是有效的 Python。你的意思是A.shape
    • @wkschwartz 已修复。不,只是预分配结果数组。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2020-06-08
    • 1970-01-01
    • 1970-01-01
    • 2020-01-25
    • 1970-01-01
    • 2020-01-09
    • 2021-02-17
    相关资源
    最近更新 更多