【问题标题】:Fast matrix multiplication of all combinations of matrix pairs in two lists两个列表中所有矩阵对组合的快速矩阵乘法
【发布时间】:2020-08-09 01:56:59
【问题描述】:

假设我们有两个 NxN 矩阵列表(L1M1 矩阵,L2M2 矩阵),我们希望找到一种有效的方法来乘以所有可能的 @987654326 @ 矩阵对。

简单示例:给定L1 = [a1, a2, a3]L2 = [b1, b2]作为输入,我们想要得到[a1.b1, a1.b2, a2.b1, a2.b2, a3.b1, a3.b2],其中.代表numpy的dot()

实际上,让我们马上把它做得更好:输入应该是两个形状为(M1, N, N)(M2, N, N) 的numpy 数组;输出的形状应该是(M1*M2, N, N)

问题:

  1. 如何在没有循环的情况下使用 numpy 有效地完成这项工作?我一直在尝试,但没有成功。
  2. 我们如何扩展它,以便除了将所有对相乘之外,我们还每次将固定的 NxN 矩阵 X 添加到每个乘积结果中。

起始代码:

import numpy as np
N  = 2
M1 = 3
M2 = 2
L1 = np.random.randn(M1, N, N)
L2 = np.random.randn(M2, N, N)
X  = np.random.randn(N, N)

【问题讨论】:

  • 通过广播,我们可以将 (m1,1,n,n) 与 (1,m2,n,n) 相乘以产生 (m1,m2,n,n)。

标签: python performance numpy numpy-ndarray


【解决方案1】:

代码:

import numpy as np

N  = 2
M1 = 3
M2 = 2
L1 = np.random.randn(M1, N, N)
L2 = np.random.randn(M2, N, N)

arr = np.dot(L1,L2).reshape(M1*M2,N,N)
arr = np.stack(arr, axis = 1).reshape(M1*M2,N,N)

print(arr.shape)
print(arr)

输出:

(6, 2, 2)
[[[-0.23801453  0.28045455]
  [-0.20878195  0.93019629]]

 [[ 0.31130404  2.03630303]
  [-0.41401161 -0.70762532]]

 [[ 2.66698736  0.34818929]
  [ 0.17368564 -0.82483313]]

 [[-0.45976939  0.25588863]
  [-1.1408594   0.74448629]]

 [[-1.98923364  1.49183873]
  [ 0.48897251 -0.46352228]]

 [[ 1.38875825 -0.21409729]
  [ 1.00406231 -0.65810354]]]

【讨论】:

  • 是的,我没有正确阅读这个问题,但这次我没有任何循环就完成了所有这些。
  • 你能解释一下你对问题 2 的确切要求吗?
  • 对于您的下一个问题,您只需添加数组即可。假设你要添加的数组是c,只需通过arr+c 将其添加到arr 中,全部由numpy 的广播功能完成
  • 太好了,谢谢。我编辑了你的答案。希望很快就会出现。
猜你喜欢
  • 1970-01-01
  • 2017-01-25
  • 1970-01-01
  • 1970-01-01
  • 2011-11-30
  • 1970-01-01
  • 2021-07-19
  • 1970-01-01
  • 2015-04-08
相关资源
最近更新 更多