【Posted】: 2020-06-17 15:17:51
【Question】:
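I'm trying to run a simple linear regression, but I get an error when I try to train the model.
The image size comes from the shape of the training data: print(dataset_train[0][0].shape) gives me torch.Size([3, 227, 227]).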
import torch
from torch import nn
from torch.utils.data import DataLoader

# dataset_train is defined elsewhere (not shown in the question)
size_of_image = 3 * 227 * 227

class linearRegression(nn.Module):
    def __init__(self, inputSize, outputSize):
        super(linearRegression, self).__init__()
        self.linear = nn.Linear(inputSize, outputSize)

    def forward(self, x):
        out = self.linear(x)
        return out

model = linearRegression(size_of_image, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()
trainloader = DataLoader(dataset=dataset_train, batch_size=1000)

for epoch in range(5):
    for x, y in trainloader:
        yhat = model(x)
        loss = criterion(yhat, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
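A minimal sketch of what each batch from the DataLoader looks like. It assumes a hypothetical stand-in dataset (stand_in_dataset, built from random tensors with TensorDataset) in place of dataset_train, and a smaller batch size so it runs quickly. The point is that the DataLoader stacks samples along a new leading dimension, so x stays 4-D rather than becoming [batch_size, 154587].

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for dataset_train: 8 random 3x227x227 images with integer labels.
images = torch.randn(8, 3, 227, 227)
labels = torch.randint(0, 2, (8,))
stand_in_dataset = TensorDataset(images, labels)

# A single sample has the shape reported in the question.
print(stand_in_dataset[0][0].shape)  # torch.Size([3, 227, 227])

# The DataLoader stacks samples along a new leading batch dimension.
loader = DataLoader(dataset=stand_in_dataset, batch_size=4)
x, y = next(iter(loader))
print(x.shape)  # torch.Size([4, 3, 227, 227]): still 4-D, not [4, 154587]
print(y.shape)  # torch.Size([4])
```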
I tried to understand what the error means, but I couldn't find a solution. Can someone help me?
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-44-6f00f9272a22> in <module>
1 for epoch in range(5):
2 for x, y in trainloader:
----> 3 yhat = model(x)
4 loss = criterion(yhat, y)
5 optimizer.zero_grad()
~/PycharmProjects/estudios/venv/lib/python3.8/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
548 result = self._slow_forward(*input, **kwargs)
549 else:
--> 550 result = self.forward(*input, **kwargs)
551 for hook in self._forward_hooks.values():
552 hook_result = hook(self, input, result)
<ipython-input-21-d20eb6e0c349> in forward(self, x)
5
6 def forward(self, x):
----> 7 out = self.linear(x)
8 return out
~/PycharmProjects/estudios/venv/lib/python3.8/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
548 result = self._slow_forward(*input, **kwargs)
549 else:
--> 550 result = self.forward(*input, **kwargs)
551 for hook in self._forward_hooks.values():
552 hook_result = hook(self, input, result)
~/PycharmProjects/estudios/venv/lib/python3.8/site-packages/torch/nn/modules/linear.py in forward(self, input)
85
86 def forward(self, input):
---> 87 return F.linear(input, self.weight, self.bias)
88
89 def extra_repr(self):
~/PycharmProjects/estudios/venv/lib/python3.8/site-packages/torch/nn/functional.py in linear(input, weight, bias)
1610 ret = torch.addmm(bias, input, weight.t())
1611 else:
-> 1612 output = input.matmul(weight.t())
1613 if bias is not None:
1614 output += bias
RuntimeError: size mismatch, m1: [681000 x 227], m2: [154587 x 1] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:41
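The numbers in the error can be checked directly. nn.Linear multiplies the last dimension of its input by weight.t(), so for a 4-D batch [1000, 3, 227, 227] the leading dimensions are folded together into a [1000*3*227 x 227] = [681000 x 227] matrix, while weight.t() is [154587 x 1]. The sketch below verifies that arithmetic and reproduces the same kind of RuntimeError at a small, assumed size (3x4x4 "images"); the exact wording of the message differs between PyTorch versions.

```python
import torch
from torch import nn

batch_size = 1000
size_of_image = 3 * 227 * 227  # in_features of the layer, i.e. weight.t() is [154587 x 1]

# Rows of m1: all dimensions except the last one are folded together.
print(batch_size * 3 * 227)  # 681000
print(size_of_image)         # 154587

# Small-scale reproduction (assumed sizes) so it runs quickly.
layer = nn.Linear(3 * 4 * 4, 1)  # expects 48 input features
x = torch.randn(2, 3, 4, 4)      # last dimension is 4, not 48

raised = False
try:
    layer(x)
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```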
【Discussion】:
-
Can you print(x.size())? Also, you need to flatten the images so that the input has size [batch_size, 3*227*227]; I'd guess it is currently [batch_size, 3, 227, 227]...
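A minimal sketch of the flattening fix, keeping the question's linearRegression class but flattening everything except the batch dimension in forward with torch.flatten(x, start_dim=1). The small 3x8x8 image size, the random stand-in data, and NUM_CLASSES = 2 are assumptions for illustration only. Separately, nn.CrossEntropyLoss expects logits of shape [N, C] and class-index targets of shape [N]; with outputSize=1 there is a single class and the loss is always zero, so the sketch uses one output per (assumed) class. A single-output model would instead pair with a regression loss such as nn.MSELoss.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class linearRegression(nn.Module):
    def __init__(self, inputSize, outputSize):
        super(linearRegression, self).__init__()
        self.linear = nn.Linear(inputSize, outputSize)

    def forward(self, x):
        # Flatten all but the batch dimension: [batch, 3, H, W] -> [batch, 3*H*W]
        x = torch.flatten(x, start_dim=1)
        return self.linear(x)


# Hypothetical small setup so the sketch runs quickly.
NUM_CLASSES = 2  # assumption: CrossEntropyLoss needs one output per class
size_of_image = 3 * 8 * 8
images = torch.randn(16, 3, 8, 8)
labels = torch.randint(0, NUM_CLASSES, (16,))
trainloader = DataLoader(TensorDataset(images, labels), batch_size=4)

model = linearRegression(size_of_image, NUM_CLASSES)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

for epoch in range(5):
    for x, y in trainloader:
        yhat = model(x)  # shape [4, NUM_CLASSES]
        loss = criterion(yhat, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

print(yhat.shape, loss.item())
```

With the real data, the same class would be constructed as linearRegression(3 * 227 * 227, NUM_CLASSES), and each [1000, 3, 227, 227] batch becomes [1000, 154587] before reaching nn.Linear.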