【问题标题】:Cannot Iterate through PyTorch MNIST dataset无法遍历 PyTorch MNIST 数据集
【发布时间】:2019-07-21 16:18:41
【问题描述】:

我正在尝试在 Pytorch 中加载 MNIST 数据集,并使用内置的数据加载器来迭代训练示例。但是,在迭代器上调用 next() 时出现错误。 CIFAR10 没有这个问题。

import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 128

dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
dataiter = iter(dataloader)
dataiter.next() # ERROR
# RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

我正在使用 Python 3.7.3 和 PyTorch 1.1.0

【问题讨论】:

    标签: python-3.x pytorch


    【解决方案1】:

    MNIST 数据集由灰度图像组成,即每个图像只有1 通道,而CIFAR10 数据集由彩色图像组成,即每个图像都有3 通道。

    因此,如果是MNIST 数据集,请将transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 替换为transforms.Normalize([0.5], [0.5])

    【讨论】:

      【解决方案2】:

      您正在尝试使用

      标准化 1 通道图像
      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
      

      这不起作用,并导致您提到的错误。您应该重新考虑您的任务需要哪些转换。

      【讨论】:

        猜你喜欢
        • 2019-06-12
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 2012-06-05
        • 2019-05-04
        • 2023-02-15
        • 2018-05-26
        相关资源
        最近更新 更多