【问题标题】:PyTorch broadcast multiplication of 4D and 2D matrix?PyTorch广播4D和2D矩阵的乘法?
【发布时间】:2020-07-21 14:59:39
【问题描述】:

如何广播以将这两个矩阵相乘?

x: torch.Size([10, 120, 180, 30]) # (N, H, W, C)
W: torch.Size([64, 30]) # (Y, C)

输出应该是:

(10, 120, 180, 64) == (N, H, W, Y)

【问题讨论】:

    标签: python pytorch broadcast torch


    【解决方案1】:

    我假设x 是某种批处理示例,w 矩阵是相应的权重。在这种情况下,您可以简单地这样做:

    out = x @ w.T
    

    这是一个张量乘法,而不是逐个元素的乘法。你不能做元素乘法来得到这样的形状,这个操作是没有意义的。你所能做的就是以某种方式unsqueeze这两个矩阵,使其形状可广播,并在你不想要的维度上应用一些操作,如下所示:

    x : torch.Size([10, 120, 180, 30, 1])
    W: torch.Size([1, 1, 1, 30, 64]) # transposition would be needed as well
    

    在这样的unsqueezing 之后,您可以沿着第三个dim 执行x*wsummean 以获得所需的形状。

    为清楚起见,两种方式并不等效。

    【讨论】:

      猜你喜欢
      • 2015-01-07
      • 1970-01-01
      • 2020-05-01
      • 2019-05-06
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2019-04-29
      • 2018-05-24
      相关资源
      最近更新 更多