【问题标题】:PyTorch DataLoader adding extra dimension for TorchVision MNISTPyTorch DataLoader 为 TorchVision MNIST 添加额外维度
【发布时间】:2019-11-28 12:26:36
【问题描述】:

我是 PyTorch 的新手,并且一直在试验 DataLoader 类。 当我尝试加载 MNIST 数据集时,DataLoader 似乎在批处理维度之后添加了一个额外的维度。我不确定是什么导致了这种情况发生。

import torch
from torchvision.datasets import MNIST
from torchvision import transforms

if __name__ == '__main__':
    mnist_train = MNIST(root='./data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor()]))
    first_x = mnist_train.data[0]
    print(first_x.shape)  # expect to see [28, 28], actual [28, 28]

    train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=200)
    batch_x, batch_y = next(iter(train_loader))  # get first batch
    print(batch_x.shape)  # expect to see [200, 28, 28], actual [200, 1, 28, 28]
    # Where is the extra dimension of 1 from?

谁能解释一下这个问题?

【问题讨论】:

    标签: pytorch torchvision


    【解决方案1】:

    我猜这是输入图像的通道数。所以基本上是

    batch_x.shape = Batch-size, No of channels, Height of the image, Width of the image

    【讨论】:

    猜你喜欢
    • 2019-07-23
    • 2018-10-07
    • 2019-06-12
    • 2019-11-19
    • 2019-06-18
    • 2020-07-07
    • 2023-03-07
    • 2016-11-07
    • 1970-01-01
    相关资源
    最近更新 更多