【发布时间】:2020-07-21 14:59:39
【问题描述】:
如何广播以将这两个矩阵相乘?
x: torch.Size([10, 120, 180, 30]) # (N, H, W, C)
W: torch.Size([64, 30]) # (Y, C)
输出应该是:
(10, 120, 180, 64) == (N, H, W, Y)
【问题讨论】:
标签: python pytorch broadcast torch
如何广播以将这两个矩阵相乘?
x: torch.Size([10, 120, 180, 30]) # (N, H, W, C)
W: torch.Size([64, 30]) # (Y, C)
输出应该是:
(10, 120, 180, 64) == (N, H, W, Y)
【问题讨论】:
标签: python pytorch broadcast torch
我假设x 是某种批处理示例,w 矩阵是相应的权重。在这种情况下,您可以简单地这样做:
out = x @ w.T
这是一个张量乘法,而不是逐个元素的乘法。你不能做元素乘法来得到这样的形状,这个操作是没有意义的。你所能做的就是以某种方式unsqueeze这两个矩阵,使其形状可广播,并在你不想要的维度上应用一些操作,如下所示:
x : torch.Size([10, 120, 180, 30, 1])
W: torch.Size([1, 1, 1, 30, 64]) # transposition would be needed as well
在这样的unsqueezing 之后,您可以沿着第三个dim 执行x*w 和sum 或mean 以获得所需的形状。
为清楚起见,两种方式并不等效。
【讨论】: