【问题标题】:Pytorch - Can not slice torchvision MNIST datasetPytorch - 无法切片 torchvision MNIST 数据集
【发布时间】:2019-06-12 14:09:38
【问题描述】:

在Pytorch中,当使用torchvision的MNIST数据集时,我们可以得到一个数字如下:

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset, TensorDataset

tsfm = transforms.Compose([transforms.Resize((16, 16)),
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))])

mnist_ds = datasets.MNIST(root='../../../_data/mnist',train=True,download=True,
                          transform=tsfm)

digit_12 = mnist_ds[12]

虽然可以对许多数据集进行切片,但我们不能对这个进行切片:

>>> digit_12_to_14 = mnist_ds[12:15]
ValueError: Too many dimensions: 3 > 2.

这是由于getItem() 中的Image.fromarray()

是否可以在不使用 Dataloader 的情况下使用 MNIST 数据集?


PS:我想避免使用 Dataloader 的原因是一次将一批发送到 GPU 会减慢训练速度。我更喜欢一次将整个数据集发送到 GPU。为此,我需要访问整个 transformed 数据集。

【问题讨论】:

标签: python pytorch dataset slice


【解决方案1】:

Dataset 接口只需要这个

所有子类都应覆盖提供数据集大小的__len____getitem__,支持从0len(self) 范围内的整数索引。

这显然没有提到切片 - 其他数据集的切片行为是一个额外的功能。如果您想一次获取全部数据,您可以查找 implementation 并使用 __init__ 末尾定义的 mnist.datamnist.targets 张量。

如果要转换数据,可以使用

data = [mnist_ds[i] for i in range(len(mnist_ds))]
xs = torch.stack([d[0] for d in data], dim=0)
ys = torch.stack([d[1] for d in data], dim=0)

或一次全部转换 mnist.data 张量(尽管这不适用于 torchvision.transform 转换)。

【讨论】:

  • 感谢您对切片的说明。 mnist.data 和 mnist.targets 不会被转换。应该使用 getIem() 来获取转换后的图像
  • 我的错,已更新。本质上,使用 Fabio Perez 的方法
【解决方案2】:

到目前为止,我找到了 2 个将 torchvision MNIST 数据集转换为张量的解决方案。第一个来自 Fábio Perez 评论:

print("\nFirst...")
st = time()
x_all_ts = torch.tensor([mnist_ds[i][0].numpy() for i in range(0, len(mnist_ds))])
t_all_ts = mnist_ds.train_labels
print(f"{time()-st}   images:{x_all_ts.size()}  targets:{t_all_ts.size()} ")

print("\nSecond...")
st = time()
mnist_dl = DataLoader(dataset=mnist_ds, batch_size=len(mnist_ds))
x_all_ts2, t_all_ts2 = list(mnist_dl)[0]
print(f"{time()-st}   images:{x_all_ts2.size()}  targets:{t_all_ts2.size()} ")


First...
19.573785066604614   images:torch.Size([60000, 1, 16, 16])  targets:torch.Size([60000]) 
Second...
16.826476573944092   images:torch.Size([60000, 1, 16, 16])  targets:torch.Size([60000]) 

如果你找到更好的请告诉我。

【讨论】:

    【解决方案3】:

    您可以使用 torch.utils.data.Subset() 获取基于索引的火炬切片 Dataset 例如:

    import torch.utils.data as data_utils
    
    indices = torch.arange(12,15)
    mnist_12to14 = data_utils.Subset(tr, indices)
    

    【讨论】:

    • 很棒的发现,谢谢 :)
    猜你喜欢
    • 2019-07-23
    • 1970-01-01
    • 1970-01-01
    • 2021-06-20
    • 1970-01-01
    • 2021-09-07
    • 2019-05-04
    • 2023-02-15
    • 2020-02-15
    相关资源
    最近更新 更多