【问题标题】:Python tensor matrix multiplyPython张量矩阵乘法
【发布时间】:2023-10-08 15:19:01
【问题描述】:

我有张量

A = 
[[[a,b],
  [c,d]],
 [[e,f],
  [g,h]]]

和矩阵

B = 
[[1,2],
 [3,4]]

我需要得到

C = 
[[a*1+e*2,b*1+f*2],
 [c*3+g*4,d*3+h*4]]

如何使用矩阵形式的 numpy 来做到这一点?我已经查看了np.tensordot(),但在这种情况下似乎没有帮助。

【问题讨论】:

    标签: python arrays numpy linear-algebra tensor


    【解决方案1】:

    OP 的问题可以使用张量表示法和所谓的Einstein summation convention 以标准格式重新表述

               A k i j  B i k   ⇒   C i j

    Numpy 有一个方便的实用函数来执行可以使用爱因斯坦求和约定描述的张量运算,不出所料地命名为numpy.einsum,它允许通过以下方式将张量符号直接映射到优化的 C 级循环一个准确反映张量符号的指令字符串'kij, ik -> ij'

    import numpy as np
    a = np.arange(8).reshape(2,2,2)+1
    b = np.arange(4).reshape(2,2)+1
    c = np.einsum('kij, ik -> ij', a, b)
    print(c)
    # [[11 14]
    #  [37 44]]
    

    优点 numpy.einsum

    1. 源代码记录了所执行操作的详细信息。
    2. np.einsum通常

      In [12]: import numpy as np 
          ...:  
          ...: i, j, k = 100, 320, 140 # just three largish numbers
          ...: a = np.random.random((k,i,j)) 
          ...: b = np.random.random((i,k))                                                      
      
      In [13]: %timeit np.einsum('kij,ik->ij', a, b)                                            
      7.47 ms ± 82.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
      
      In [14]: %timeit (a * b[None,:,:].T).sum(axis = 0)                                        
      49.3 ms ± 6.77 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
      

    【讨论】:

      【解决方案2】:

      你可以试试这个:

      >>> import numpy as np
      >>> a = np.arange(1,9).reshape(2,2,2)
      >>> a
      array([[[1, 2],
              [3, 4]],
      
             [[5, 6],
              [7, 8]]])
      >>> b = np.arange(1,5).reshape(2,2)
      >>> b
      array([[1, 2],
             [3, 4]])
      >>> (a * b[None,:,:].T).sum(axis = 0)
      array([[11, 14],
             [37, 44]])
      

      中间步骤如下所示:

      >>> b[None,:,:]
      array([[[1, 2],
              [3, 4]]])
      >>> b[None,:,:].T
      array([[[1],
              [3]],
      
             [[2],
              [4]]])
      

      【讨论】: