【Question Title】: PyTorch vision size mismatch, m1
【Posted】: 2020-06-17 15:17:51
【Question Description】:

I'm trying to run a simple linear regression, but I get an error when I try to train it.

The images in the training dataset have this shape: print(dataset_train[0][0].shape) shows torch.Size([3, 227, 227]).


import torch
from torch import nn
from torch.utils.data import DataLoader

size_of_image = 3 * 227 * 227

class linearRegression(nn.Module):
    def __init__(self, inputSize, outputSize):
        super(linearRegression, self).__init__()
        self.linear = nn.Linear(inputSize, outputSize)

    def forward(self, x):
        out = self.linear(x)
        return out

model = linearRegression(size_of_image, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()         
trainloader = DataLoader(dataset = dataset_train, batch_size = 1000)
for epoch in range(5):
    for x, y in trainloader:
        yhat = model(x)
        loss = criterion(yhat, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()            

I'm trying to understand what the error means, but I haven't found a solution. Can anyone help me?

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-44-6f00f9272a22> in <module>
      1 for epoch in range(5):
      2     for x, y in trainloader:
----> 3         yhat = model(x)
      4         loss = criterion(yhat, y)
      5         optimizer.zero_grad()

~/PycharmProjects/estudios/venv/lib/python3.8/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

<ipython-input-21-d20eb6e0c349> in forward(self, x)
      5 
      6     def forward(self, x):
----> 7         out = self.linear(x)
      8         return out

~/PycharmProjects/estudios/venv/lib/python3.8/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

~/PycharmProjects/estudios/venv/lib/python3.8/site-packages/torch/nn/modules/linear.py in forward(self, input)
     85 
     86     def forward(self, input):
---> 87         return F.linear(input, self.weight, self.bias)
     88 
     89     def extra_repr(self):

~/PycharmProjects/estudios/venv/lib/python3.8/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1610         ret = torch.addmm(bias, input, weight.t())
   1611     else:
-> 1612         output = input.matmul(weight.t())
   1613         if bias is not None:
   1614             output += bias

RuntimeError: size mismatch, m1: [681000 x 227], m2: [154587 x 1] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:41

【Comments】:

  • Can you print(x.size())? Also, you need to flatten the images so their size is [batch_size, 3*227*227]; I suspect it is currently [batch_size, 3, 227, 227]...

Tags: python pytorch


【Solution 1】:

You need to flatten the input images:
your input is a 4D tensor of shape 1000-3-227-227 (batch-channels-height-width). However, nn.Linear expects a 2D input tensor of shape batch-features.

Your forward method should look something like:

def forward(self, x):
  flat_x = x.view(x.shape[0], -1)  # collapse all dimensions after the batch dimension into one
  out = self.linear(flat_x)
  return out
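A minimal sketch of the fix (a random tensor stands in for the image batch, and a smaller batch size of 4 is used so it runs quickly; shapes are otherwise the ones from the question):

```python
import torch
from torch import nn

x = torch.randn(4, 3, 227, 227)      # same layout as the images, smaller batch
flat_x = x.view(x.shape[0], -1)      # [4, 3, 227, 227] -> [4, 154587]
linear = nn.Linear(3 * 227 * 227, 1)
out = linear(flat_x)                 # now the feature dims match: [4, 1]
print(flat_x.shape, out.shape)
```

With the flatten in place, the last dimension of the input (154587) matches the layer's in_features, so the matmul succeeds.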

【Discussion】:

  • I made the change, but now it tells me: RuntimeError: Assertion `cur_target >= 0 && cur_target
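That follow-up error comes from CrossEntropyLoss: it expects integer class targets in the range [0, C-1], where C is the model's number of output features. With out_features=1 there is only one valid class (0), so any other label trips the assertion. A small sketch (hypothetical shapes; the exact exception type varies across PyTorch versions):

```python
import torch
from torch import nn

criterion = nn.CrossEntropyLoss()
logits = torch.randn(4, 1)  # a model with out_features=1 yields C=1 "classes"
ok_loss = criterion(logits, torch.zeros(4, dtype=torch.long))  # targets all 0: valid
err = None
try:
    criterion(logits, torch.ones(4, dtype=torch.long))  # target 1 >= C=1: rejected
except Exception as e:  # older builds raise RuntimeError, newer ones IndexError
    err = e
print(type(err).__name__)
```

For real classification you would set out_features to the number of classes; for regression on a single value, nn.MSELoss is the usual choice instead of CrossEntropyLoss.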
【Solution 2】:

In linearRegression you defined the linear transformation as nn.Linear(3*227*227, 1), which means the Linear layer expects 3*227*227 input features and outputs 1 feature.

However, you are feeding the linear layer a 4D tensor of shape [1000, 3, 227, 227] (batch-channels-height-width), and the layer treats the last dimension as the feature dimension. That means the linear layer receives 227 input features instead of 3*227*227, which is why you get the following error:

RuntimeError: size mismatch, m1: [681000 x 227], m2: [154587 x 1]

Note that the linear layer multiplies by a weight matrix of shape in_features x out_features (in your case, [154587 x 1]), and that the input to the linear layer is collapsed to a 2D tensor, in your case [1000*3*227 x 227] = [681000 x 227].

So attempting a matrix multiplication between two tensors of shapes [681000 x 227] and [154587 x 1] produces the error above.
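As a sanity check, the numbers in the error message follow directly from the shapes (pure arithmetic, nothing assumed beyond the shapes quoted above):

```python
# m1 comes from the 4D input with leading dims collapsed; m2 from the weight.
batch, c, h, w = 1000, 3, 227, 227
m1_rows = batch * c * h   # leading dims collapsed together -> 681000
m1_cols = w               # last dim treated as the feature dim -> 227
m2_rows = c * h * w       # in_features of nn.Linear(3*227*227, 1) -> 154587
print(m1_rows, m1_cols, m2_rows)
```

Since m1_cols (227) != m2_rows (154587), the multiplication cannot proceed; flattening the input as in Solution 1 makes the input [1000, 154587], which does match.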

【Discussion】:
