【Question Title】: Get a smaller MNIST dataset in PyTorch
【Posted】: 2023-02-15 18:36:12
【Question】:
This is how I load the dataset, but it is too large: there are about 60k images. I want to limit it to 1/10 of that for training. Is there a built-in way to do this?
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

train_data = datasets.MNIST(
    root='data',
    train=True,
    transform=transforms.Compose(
        [transforms.ToTensor()]
    ),
    download=True
)
print(train_data)
print(train_data.data.size())
print(train_data.targets.size())

loaders = {
    'train': DataLoader(train_data, batch_size=100),
}
【Comments】:
Tags:
python
machine-learning
pytorch
【Solution 1】:
You can use the torch.utils.data.Subset class, which takes a dataset and a sequence of indices and selects only the elements corresponding to those indices:
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import Subset

train_data = datasets.MNIST(
    root='data',
    train=True,
    transform=transforms.Compose(
        [transforms.Resize(32), transforms.ToTensor()]
    ),
    download=True
)
# take the first 10% of the MNIST train set
subset_train = Subset(train_data, indices=range(len(train_data) // 10))
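The Subset above behaves like any other dataset and can be fed straight to a DataLoader. A minimal sketch, using a hypothetical TensorDataset stand-in (fake_data, 600 random 1x28x28 tensors) so it runs without downloading MNIST; swap in train_data from above for the real thing:

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Hypothetical stand-in for MNIST: 600 fake 1x28x28 images with integer labels.
fake_data = TensorDataset(torch.randn(600, 1, 28, 28), torch.randint(0, 10, (600,)))

# Keep the first 10% of samples, as in the answer above.
subset_train = Subset(fake_data, indices=range(len(fake_data) // 10))
loader = DataLoader(subset_train, batch_size=20)

print(len(subset_train))       # 60
print(sum(1 for _ in loader))  # 3 batches of 20
```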
【Solution 2】:
Note that @aretor's answer will not cover all data points: it only takes the starting data points of MNIST, i.e. classes 0 and 1.
So use the block below instead:
import torch
from torch.utils.data import DataLoader

tr_split_len = 6000  # desired training-subset size
train = datasets.MNIST('../data', train=True, download=True, transform=transform)
part_tr = torch.utils.data.random_split(train, [tr_split_len, len(train) - tr_split_len])[0]
train_loader = DataLoader(part_tr, batch_size=args.batch_size, shuffle=True, num_workers=4)
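To see what random_split does in isolation, here is a minimal sketch on a hypothetical 100-sample TensorDataset (so it runs without downloading MNIST); the [0] indexing keeps the first of the two returned Subsets:

```python
import torch
from torch.utils.data import TensorDataset

# Hypothetical tiny dataset standing in for the MNIST train set.
train = TensorDataset(torch.arange(100).float(), torch.arange(100))
tr_split_len = 10  # assumed subset size

# random_split returns one randomly drawn Subset per requested length.
part_tr = torch.utils.data.random_split(train, [tr_split_len, len(train) - tr_split_len])[0]
print(len(part_tr))  # 10
```

Because the indices are drawn randomly over the whole dataset, the subset is not biased toward the first classes.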
【Solution 3】:
aretor's answer doesn't shuffle the data, and Prajot's answer wastefully creates a second, unused split. Here is what is IMO a better solution, using SubsetRandomSampler:
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler

K = 6000  # enter your subset length here
subsample_train_indices = torch.randperm(len(train_data))[:K]
train_loader = DataLoader(train_data, batch_size=batch_size, sampler=SubsetRandomSampler(subsample_train_indices))
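A self-contained sketch of the same sampler idea, using a hypothetical 100-sample TensorDataset in place of train_data so it runs without downloading MNIST; the loader yields exactly K samples per epoch, in a fresh random order each time:

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# Hypothetical 100-sample dataset standing in for train_data above.
train_data = TensorDataset(torch.arange(100).float(), torch.arange(100))
K = 25
subsample_train_indices = torch.randperm(len(train_data))[:K]
loader = DataLoader(train_data, batch_size=5,
                    sampler=SubsetRandomSampler(subsample_train_indices))

# The sampler restricts iteration to the K chosen indices.
seen = sum(x.shape[0] for x, _ in loader)
print(seen)  # 25
```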