【问题标题】:PyTorch Dataset / Dataloader batchingPyTorch 数据集/数据加载器批处理
【发布时间】:2020-10-09 04:42:58
【问题描述】:

对于在时间序列数据上实现 PyTorch 数据管道的“最佳实践”,我有点困惑。

我有一个使用自定义 DataLoader 读取的 HD5 文件。似乎我应该将数据样本作为 (features,targets) 元组返回,每个元组的形状为 (L,C),其中 L 是 seq_len,C 是通道数 - 即不要在数据加载器中进行批处理,只是作为表格返回。

PyTorch 模块似乎需要批量调暗,即 Conv1D 需要 (N, C, L)。

我的印象是DataLoader 类会预先设置批处理维度,但事实并非如此,我正在获取数据形状 (N,L)。

dataset = HD5Dataset(args.dataset)

dataloader = DataLoader(dataset,
                        batch_size=N,
                        shuffle=True,
                        pin_memory=is_cuda,
                        num_workers=num_workers)

for i, (x, y) in enumerate(train_dataloader):
    ...

在上面的代码中,x 的形状是 (N,C) 而不是 (1,N,C),这导致下面的代码(来自公共 git 存储库)在第一行失败。

def forward(self, x):
    """expected input shape is (N, L, C)"""
    x = x.transpose(1, 2).contiguous() # input should have dimension (N, C, L)

文档说明启用自动批处理时 它总是预先添加一个新维度作为批处理维度,这让我相信自动批处理已禁用 strong> 但我不明白为什么?

【问题讨论】:

  • “我得到的数据形状是 (N,L)”“x 的形状是 (N,C)”这两个陈述是矛盾的。你有一个错字吗? dataset的形状是什么?

标签: python pytorch pytorch-dataloader


【解决方案1】:

我发现了一些似乎可行的方法,一种选择似乎是使用 DataLoader 的collate_fn,但更简单的选择是使用BatchSampler,即

dataset = HD5Dataset(args.dataset)
train, test = train_test_split(list(range(len(dataset))), test_size=.1)

train_dataloader = DataLoader(dataset,
                        pin_memory=is_cuda,
                        num_workers=num_workers,
                        sampler=BatchSampler(SequentialSampler(train),batch_size=len(train), drop_last=True)
                        )

test_dataloader = DataLoader(dataset,
                        pin_memory=is_cuda,
                        num_workers=num_workers,
                        sampler=BatchSampler(SequentialSampler(test),batch_size=len(test), drop_last=True)
                        )

for i, (x, y) in enumerate(train_dataloader):
    print (x,y)

这会将数据集 dim (L, C) 转换为单个批次的 (1, L, C)(效率不是特别高)。

【讨论】:

    【解决方案2】:

    如果您有一个由张量对 (x, y) 组成的数据集,其中每个 x 的形状为 (C,L),那么:

    N, C, L = 5, 3, 10
    dataset = [(torch.randn(C,L), torch.ones(1)) for i in range(50)]
    dataloader = data_utils.DataLoader(dataset, batch_size=N)
    
    for i, (x,y) in enumerate(dataloader):
        print(x.shape)
    

    将生产(50/N)=10批形状(N,C,L)x

    torch.Size([5, 3, 10])
    torch.Size([5, 3, 10])
    torch.Size([5, 3, 10])
    torch.Size([5, 3, 10])
    torch.Size([5, 3, 10])
    torch.Size([5, 3, 10])
    torch.Size([5, 3, 10])
    torch.Size([5, 3, 10])
    torch.Size([5, 3, 10])
    torch.Size([5, 3, 10])
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2021-08-25
      • 2020-03-14
      • 2022-01-23
      • 1970-01-01
      • 1970-01-01
      • 2020-01-17
      • 1970-01-01
      相关资源
      最近更新 更多