【发布时间】:2017-12-21 22:31:54
【问题描述】:
我正在尝试使用多个 torch.utils.data.DataLoaders 来创建应用了不同转换的数据集。目前,我的代码大致是
d_transforms = [
transforms.RandomHorizontalFlip(),
# Some other transforms...
]
loaders = []
for i in range(len(d_transforms)):
dataset = datasets.MNIST('./data',
train=train,
download=True,
transform=d_transforms[i]
loaders.append(
DataLoader(dataset,
shuffle=True,
pin_memory=True,
num_workers=1)
)
这可行,但速度极慢。 kernprof 表明我的代码中几乎所有时间都花在了类似
的行上x, y = next(iter(train_loaders[i]))
我怀疑这是因为我使用了DataLoader 的多个实例,每个实例都有自己的工作人员,它试图读取相同的数据文件。
我的问题是,有什么更好的方法来做到这一点?理想情况下,我会将torch.utils.data.DataSet 子类化并指定我想在采样时应用的变换,但这似乎是不可能的,因为__getitem__ 无法接受参数。
【问题讨论】:
标签: pytorch