【问题标题】:Why doesn't this input pass through this simple PyTorch model?为什么这个输入不通过这个简单的 PyTorch 模型?
【发布时间】:2021-05-10 08:30:23
【问题描述】:

我有一个输入/张量,其形状为:

torch.Size([256, 3, 28, 28])

(这里的批量大小为 256,3 通道,28x28 图像)

还有这样的模型:

class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(3, 28, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(28, 56, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # output: 56 x 16 x 16

            nn.Conv2d(56, 112, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(112, 112, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # output: 112 x 8 x 8

            nn.Conv2d(112, 224, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(224, 224, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # output: 224 x 4 x 4

            nn.Flatten(),
            nn.Linear(224 * 4 * 4, 896),
            nn.ReLU(),
            nn.Linear(896, 512),
            nn.ReLU(),
            nn.Linear(512, 2))

    def forward(self, xb):
        return self.network(xb)

当我尝试向前传递数据时,它失败了:

    ...
    return self.network(xb)
  File "/home/stark/anaconda3/envs/torch-env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/stark/anaconda3/envs/torch-env/lib/python3.8/site-packages/torch/nn/modules/container.py", line 117, in forward
    input = module(input)
  File "/home/stark/anaconda3/envs/torch-env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/stark/anaconda3/envs/torch-env/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/stark/anaconda3/envs/torch-env/lib/python3.8/site-packages/torch/nn/functional.py", line 1690, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: mat1 dim 1 must match mat2 dim 0

我错过了什么?

谢谢!

【问题讨论】:

    标签: python python-3.x pytorch conv-neural-network


    【解决方案1】:

    nn.MaxPool2d(2, 2), # 输出:56 x 16 x 16

    这是错误的。原始输入的大小为 (256, 3, 28, 28)。您使用的卷积层和 ReLU 层不会改变批次、高度或宽度尺寸;他们只改变“渠道”维度。在最大池化层之前,张量大小为 (256, 56, 28, 28)。最大池化层的内核大小为 2,步幅为 2,因此它将高度和宽度都减半。所以这个最大池化层的输出大小为 (256, 56, 14, 14)。

    出于同样的原因,下一个最大池化层的输出大小为 (256, 112, 7, 7),最后一个最大池化层的输出大小为 (256, 224, 3, 3) .

    因此,您可以通过将输入大小更改为 (256, 3, 32, 32) 来解决此问题(如果这是一个选项),或者将第一个线性层更改为 nn.Linear(224 * 3 * 3, 896)

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2020-04-17
      • 2012-06-13
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2016-09-19
      相关资源
      最近更新 更多