【发布时间】:2019-06-12 14:09:38
【问题描述】:
在Pytorch中,当使用torchvision的MNIST数据集时,我们可以得到一个数字如下:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset, TensorDataset
tsfm = transforms.Compose([transforms.Resize((16, 16)),
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])
mnist_ds = datasets.MNIST(root='../../../_data/mnist',train=True,download=True,
transform=tsfm)
digit_12 = mnist_ds[12]
虽然可以对许多数据集进行切片,但我们不能对这个进行切片:
>>> digit_12_to_14 = mnist_ds[12:15]
ValueError: Too many dimensions: 3 > 2.
这是由于getItem() 中的Image.fromarray()。
是否可以在不使用 Dataloader 的情况下使用 MNIST 数据集?
PS:我想避免使用 Dataloader 的原因是一次将一批发送到 GPU 会减慢训练速度。我更喜欢一次将整个数据集发送到 GPU。为此,我需要访问整个 transformed 数据集。
【问题讨论】:
-
[mnist_ds[i] for i in range(12,15)]怎么样? -
谢谢。这有效
标签: python pytorch dataset slice