【发布时间】:2020-01-03 20:35:31
【问题描述】:
我想知道如何在 PyTorch 中使用 torch.utils.data.DataLoader,尤其是在多工人的情况下。
我发现DataLoader 的一批输出总是来自一个工人。
我希望 DataLoader 中有一个队列,它存储来自所有工作人员的数据,并且 DataLoader 将它们打乱在队列中以输出随机批处理数据。我认为这就是 Tensorflow 中tf.data.Dataset 的方式。
我们可以在 PyTorch 中实现类似的功能吗?我想通过使用多工作人员从大型序列化文件(如Tfrecord)加载数据集。在这种情况下,在一批中混合源文件,也就是混合worker的源,就很重要了。
请参考以下代码:
import random
import time
import torch
class MyDataset(torch.utils.data.Dataset):
def __len__(self):
return 50
def __getitem__(self, idx):
info = torch.utils.data.get_worker_info()
time.sleep(random.uniform(0, 1))
print("[{}]:{}".format(info.id, idx))
return idx, info.id
if __name__ == '__main__':
dataset = MyDataset()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=False, num_workers=2)
for batch in dataloader:
print(batch)
输出:
[0]:0
[1]:5
[0]:1
[1]:6
[0]:2
[0]:3
[1]:7
[0]:4
[tensor([0, 1, 2, 3, 4]), tensor([0, 0, 0, 0, 0])]
[1]:8
[1]:9
[tensor([5, 6, 7, 8, 9]), tensor([1, 1, 1, 1, 1])]
[0]:10
[0]:11
[1]:15
[1]:16
[0]:12
[1]:17
...
这里,[tensor([0, 1, 2, 3, 4]), tensor([0, 0, 0, 0, 0])] 中的[0, 1, 2, 3, 4] 和[0, 0, 0, 0, 0] 表示该批次包含来自worker id 0 的索引0-th 到4-th 数据。
注意shuffle=True 并不能解决这个问题,它只会改变数据的索引。
在这种情况下,我想得到一个像:[tensor([0, 5, 1, 6, 2]), tensor([0, 1, 0, 1, 0])] 这样的批次。
【问题讨论】:
标签: pytorch dataloader