【问题标题】:Batch matrix multiplication in numpynumpy中的批量矩阵乘法
【发布时间】:2021-09-26 01:02:03
【问题描述】:

我有两个形状分别为[5, 5, 5][5, 5] 的numpy 数组ab。对于ab,形状中的第一个条目是批量大小。当我执行矩阵乘法选项时,我得到一个形状为[5, 5, 5] 的数组。 MWE如下。

import numpy as np

a = np.ones((5, 5, 5))
b = np.random.randint(0, 10, (5, 5))
c = a @ b
# c.shape is (5, 5, 5)

假设我要在批量大小上运行一个循环,即a[0] @ b[0].T,它将产生一个形状为[5, 1] 的数组。最后,如果我沿轴 1 连接所有结果,我将得到一个形状为 [5, 5] 的结果数组。下面的代码更好地描述了这些行。

a = np.ones((5, 5, 5))
b = np.random.randint(0, 10, (5, 5))
c = []
for i in range(5):
    c.append(a[i] @ b[i].T)
c = np.concatenate([d[:, None] for d in c], axis=1).T
# c.shape evaluates to be (5, 5)

我可以在不使用循环的情况下获得上述功能吗?例如,PyTorch 提供了一个名为 torch.bmm 的函数来计算它。谢谢。

【问题讨论】:

  • 看看我的回答。希望对您有所帮助。

标签: python arrays python-3.x numpy matrix-multiplication


【解决方案1】:

b 中添加一个额外的维度以使矩阵乘法batch 兼容,并通过squeezing 删除末尾多余的最后一个维度:

c = np.matmul(a, b[:, :, None]).squeeze(-1)

或等效:

c = (a @ b[:, :, None]).squeeze(-1)

在您的示例中,通过将 b 重塑为 5 x 5 x 1,两者都使 ab 的矩阵乘法适当。

【讨论】:

  • 谢谢,虽然您的回答适用于batch_size=5 的情况,但恐怕对于其他批量大小,它会引发尺寸不匹配的错误。
  • 你能提供更一般的设置吗?只要a 的形状为B x M x NbB x N 形状,它就应该可以工作,这将使c形状 B x M.
  • 我有一个转置操作员在场,删除解决方案有效,谢谢!
【解决方案2】:

您可以使用 numpy einsum 解决此问题。

c = np.einsum('BNi,Bi ->BN', a, b)

Pytorch 也提供了这个 einsum 函数,但语法略有变化。所以你可以很容易地解决它。它也可以轻松处理其他形状。

那么您就不必担心转置或挤压操作了。它还可以节省内存,因为不会在内部创建现有矩阵的副本。

【讨论】:

  • 谢谢。但我认为我的问题np.einsum("BNi, Bi->BN", a, b) 有效。 +1
  • 我有一个与此相关的问题。但它被关闭了。所以我不得不重新问同样的问题。然后它也被关闭了。 stackoverflow.com/q/68424045/5462372你能帮忙解决一下吗?
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2023-01-11
  • 2014-03-01
相关资源
最近更新 更多