【发布时间】:2018-05-06 01:01:05
【问题描述】:
我有一个网络,我想在一些数据集上进行训练(例如,CIFAR10)。我可以通过
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
我的问题如下:假设我想做几个不同的训练迭代。假设我想首先在奇数位置的所有图像上训练网络,然后在偶数位置的所有图像上训练网络,依此类推。为此,我需要能够访问这些图像。不幸的是,trainset 似乎不允许此类访问。也就是说,尝试执行trainset[:1000] 或更一般的trainset[mask] 会引发错误。
我可以这样做
trainset.train_data=trainset.train_data[mask]
trainset.train_labels=trainset.train_labels[mask]
然后
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
但是,这将迫使我在每次迭代中创建完整数据集的新副本(因为我已经更改了 trainset.train_data,所以我需要重新定义 trainset)。有什么办法可以避免吗?
理想情况下,我想要一些“等价于”的东西
trainloader = torch.utils.data.DataLoader(trainset[mask], batch_size=4,
shuffle=True, num_workers=2)
【问题讨论】:
标签: python machine-learning neural-network torch pytorch