【问题标题】:Expected input batch_size (500) to match target batch_size (1000)预期输入 batch_size (500) 与目标 batch_size (1000) 匹配
【发布时间】:2021-03-09 18:46:07
【问题描述】:

我正在尝试在 PyTorch 中使用 MNIST 数据训练 CNN。但是,我得到 ValueError: Expected input batch_size (500) to match target batch_size (1000)。 当我在下面的代码中运行 test() 命令时会发生这种情况。我已经查找了此问题的解决方案,但没有一个可以帮助解决此问题。

我的代码如下:

n_epochs = 20
batch_size_train = 64
batch_size_test = 1000
learning_rate = 1e-4
log_interval = 50

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=5)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=1)
        self.fc1 = nn.Linear(9216, 100)
        self.fc2 = nn.Linear(100, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 9216)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
    
    def loss_function(self, out, target):
        return F.cross_entropy(out, target)

def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

network = Net()
network.apply(init_weights)
network.cuda()

optimizer = optim.Adam(network.parameters(), lr=1e-4)

def train(epoch):
  network.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    data = data.cuda()
    target = target.cuda()
    optimizer.zero_grad()
    output = network(data)
    loss = network.loss_function(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % log_interval == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))

def test():
  network.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      data = data.cuda()
      target = target.cuda()
      target = target.view(batch_size_test)
      output = network(data)
      test_loss += network.loss_function(output, target).item()
      pred = output.data.max(1, keepdim=True)[1]
      correct += pred.eq(target.data.view_as(pred)).sum()
  test_loss /= len(test_loader.dataset)
  print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

test()
for epoch in range(1, n_epochs + 1):
  train(epoch)
  test()

完整的错误日志:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-12-ef6e122ea50c> in <module>()
----> 1 test()
      2 for epoch in range(1, n_epochs + 1):
      3   train(epoch)
      4   test()

3 frames
<ipython-input-9-23a4b65d1ae9> in test()
      9       target = target.view(batch_size_test)
     10       output = network(data)
---> 11       test_loss += network.loss_function(output, target).item()
     12       pred = output.data.max(1, keepdim=True)[1]
     13       correct += pred.eq(target.data.view_as(pred)).sum()

<ipython-input-5-d97bf44ef6f0> in loss_function(self, out, target)
     91 
     92     def loss_function(self, out, target):
---> 93         return F.cross_entropy(out, target)

/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2466     if size_average is not None or reduce is not None:
   2467         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2468     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
   2469 
   2470 

/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   2260     if input.size(0) != target.size(0):
   2261         raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
-> 2262                          .format(input.size(0), target.size(0)))
   2263     if dim == 2:
   2264         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

ValueError: Expected input batch_size (500) to match target batch_size (1000).

请告诉我如何解决此问题。 谢谢, 文尼

【问题讨论】:

  • test_loss += network.loss_function(output, target).item() 是否引发此错误?
  • 我将完整的错误日志添加到问题描述@Ivan。简短的回答 - 是的,那是 test() 中导致错误的行。
  • 什么是test_loader
  • test_loader = torch.utils.data.DataLoader( torchvision.datasets.MNIST('/files/', train=False, download=True, transform=torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize( (0.5,), (0.5,)) ])), batch_size=batch_size_test, shuffle=True) @SergiiDymchenko
  • 我设法重现了错误,现在正在查看它。

标签: python pytorch conv-neural-network mnist


【解决方案1】:

您的数据具有以下形状[batch_size, c=1, h=28, w=28]batch_size 等于火车的 64 和测试集的 1000,但这没有任何区别,我们不应该处理第一个暗淡。

要使用F.cross_entropy,你必须提供一个大小为[batch_size, nb_classes]的张量,这里nb_classes是10。所以你的模型的最后一层应该总共有10个神经元。

附带说明,在使用此标准时,您不应在模型的输出中使用 F.log_softmax请参阅 here)。

此标准将 log_softmax 和 nll_loss 组合在一个函数中。

但这不是问题。问题是您的模型没有输出 [batch_size, 10] 张量。问题在于您使用了view:张量从torch.Size([64, 128, 6, 6]) 变为torch.Size([32, 9216])。您基本上说过“在 dim=1 上将所有内容压缩到总共 9216 (128*6*6*64/2) 并让其余的 (32) 保持在 dim=0 上”。这是不希望的,因为您弄乱了批次。 在您的 CNN 层之后,在此特定实例中使用 Flatten 层更容易。这将使每个通道的所有值变平。确保使用start_dim=1 保留第一个维度。

这是一个示例,旨在显示,层是随机,但代码会运行。您应该根据自己的喜好调整内核大小、通道数等!

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=4)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=8)
        self.fc1 = nn.Linear(128, 100)
        self.fc2 = nn.Linear(100, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

【讨论】:

    猜你喜欢
    • 2021-04-03
    • 2020-08-25
    • 2021-09-04
    • 2021-03-06
    • 2021-04-15
    • 2019-11-05
    • 2020-01-23
    • 2021-01-29
    • 2019-07-22
    相关资源
    最近更新 更多