【问题标题】:Pytorch NN error: Expected input batch_size (64) to match target batch_size (30)Pytorch NN 错误:预期输入 batch_size (64) 与目标 batch_size (30) 匹配
【发布时间】:2021-01-29 06:03:39
【问题描述】:

我目前正在训练一个神经网络来对食物图像的食物组进行分类,从而产生 5 个输出类。但是,每当我开始训练网络时,都会出现以下错误:

ValueError: Expected input batch_size (64) to match target batch_size (30).

这是我的神经网络定义和训练代码。我真的很需要帮助,我对 pytorch 比较陌生,无法弄清楚我的代码中到底有什么问题。谢谢!

#Define the Network Architechture

model = nn.Sequential(nn.Linear(7500, 4950),
                      nn.ReLU(),
                      nn.Linear(4950, 1000),
                      nn.ReLU(),
                      nn.Linear(1000, 250),
                      nn.ReLU(),
                      nn.Linear(250, 5),
                      nn.LogSoftmax(dim = 1))


#Define loss
criterion = nn.NLLLoss()

#Initial forward pass
images, labels = next(iter(trainloader))
images = images.view(images.shape[0], -1)
print(images.shape)

logits = model(images)
print(logits.size)
loss = criterion(logits, labels)
print(loss)

#Define Optimizer
optimizer = optim.SGD(model.parameters(), lr = 0.01)

训练网络:

epochs = 10

for e in range(epochs):
    running_loss = 0
    for image, labels in trainloader:
        #Flatten Images
        images = images.view(images.shape[0], -1)
        #Set gradients to 0
        optimizer.zero_grad()

        #Output
        output = model(images)
        loss = criterion(output, labels) #Where the error occurs
        loss.backward()

        #Gradient Descent Step
        optimizer.step()
        running_loss += loss.item()
    else:
        print(f"Training loss: {running_loss/len(trainloader)}")

【问题讨论】:

    标签: python neural-network pytorch image-classification


    【解决方案1】:

    错误出现在“for image, labels in trainloader:”这一行(应该是图像)。修好了,模型现在训练好了。

    【讨论】:

      【解决方案2】:

      不是 100% 确定,但我认为错误在这一行:

      nn.Linear(7500, 4950)
      

      除非您绝对确定您的输入是 7500,否则请使用 1 而不是 7500。请记住,第一个值将始终是您的输入大小。通过输入 1,您将确保您的模型可以处理任何大小的图像。

      顺便说一下,PyTorch 有一个 flatten 功能。使用nn.Flatten 而不是images.view(),因为您不想犯任何形状错误并必然浪费更多时间。

      您犯的另一个小错误是您继续在 for 循环中使用 images and image 作为变量和参数。这是非常糟糕的做法,因为每当有人阅读您的代码时,您都会使他们感到困惑。确保不要一遍又一遍地重复使用相同的变量。

      另外,您能否提供有关您的数据的更多信息?比如是灰度、image_size 等等。

      【讨论】:

      • 谢谢!事实证明,问题是一个错字,毕竟“对于图像,trainloader 中的标签:”(应该是图像)。感谢您的宝贵时间!
      猜你喜欢
      • 2019-11-05
      • 2019-07-22
      • 2020-01-23
      • 1970-01-01
      • 2022-01-18
      • 2021-04-03
      • 2021-03-09
      • 2020-08-25
      • 2021-09-04
      相关资源
      最近更新 更多