【问题标题】:Pytorch - Concatenating Datasets before using DataloaderPytorch - 在使用 Dataloader 之前连接数据集
【发布时间】:2020-07-05 12:09:18
【问题描述】:

我正在尝试加载两个数据集并将它们都用于训练。

包版本:python 3.7; pytorch 1.3.1

可以单独创建data_loaders并按顺序训练它们:

from torch.utils.data import DataLoader, ConcatDataset


train_loader_modelnet = DataLoader(ModelNet(args.modelnet_root, categories=args.modelnet_categories,split='train', transform=transform_modelnet, device=args.device),batch_size=args.batch_size, shuffle=True)

train_loader_mydata = DataLoader(MyDataset(args.customdata_root, categories=args.mydata_categories, split='train', device=args.device),batch_size=args.batch_size, shuffle=True)

for e in range(args.epochs):
    for idx, batch in enumerate(tqdm(train_loader_modelnet)):
        # training on dataset1
    for idx, batch in enumerate(tqdm(train_loader_custom)):
        # training on dataset2

注意:MyDataset 是一个自定义数据集类,它实现了 def __len__(self): def __getitem__(self, index):。由于上述配置有效,看来这是实施还可以。

但理想情况下,我希望将它们组合成一个数据加载器对象。我根据 pytorch 文档尝试了此操作:

train_modelnet = ModelNet(args.modelnet_root, categories=args.modelnet_categories,
                          split='train', transform=transform_modelnet, device=args.device)
train_mydata = CloudDataset(args.customdata_root, categories=args.mydata_categories,
                             split='train', device=args.device)
train_loader = torch.utils.data.ConcatDataset(train_modelnet, train_customdata)

for e in range(args.epochs):
    for idx, batch in enumerate(tqdm(train_loader)):
        # training on combined

但是,在随机批次中,我得到以下“期望张量作为参数 0 中的元素 X,但得到一个元组”类型的错误。任何帮助将不胜感激!

>   40%|████      | 53/131 [01:03<02:00,  1.55s/it]
>  Traceback (mostrecent call last):   File
> "/home/chris/Programs/pycharm-anaconda-2019.3.4/plugins/python/helpers/pydev/pydevd.py",
> line 1434, in _exec
>     pydev_imports.execfile(file, globals, locals)  # execute the script   File
> "/home/chris/Programs/pycharm-anaconda-2019.3.4/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
>     exec(compile(contents+"\n", file, 'exec'), glob, loc)   File "/home/chris/Documents/4yp/Data/my_kaolin/Classification/pointcloud_classification_combinedset.py",
> line 83, in <module>
>     for idx, batch in enumerate(tqdm(train_loader)):   File "/home/chris/anaconda3/envs/4YP/lib/python3.7/site-packages/tqdm/std.py",
> line 1107, in __iter__
>     for obj in iterable:   File "/home/chris/anaconda3/envs/4YP/lib/python3.7/site-packages/torch/utils/data/dataloader.py",
> line 346, in __next__
>     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration   File
> "/home/chris/anaconda3/envs/4YP/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py",
> line 47, in fetch
>     return self.collate_fn(data)   File "/home/chris/anaconda3/envs/4YP/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py",
> line 79, in default_collate
>     return [default_collate(samples) for samples in transposed]   File "/home/chris/anaconda3/envs/4YP/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py",
> line 79, in <listcomp>
>     return [default_collate(samples) for samples in transposed]   File "/home/chris/anaconda3/envs/4YP/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py",
> line 55, in default_collate
>     return torch.stack(batch, 0, out=out) TypeError: expected Tensor as element 3 in argument 0, but got tuple  

【问题讨论】:

    标签: python tensorflow machine-learning dataset pytorch


    【解决方案1】:

    如果你的问题没有错:你有如下训练集和开发集(以及它们相应的加载器)。

    train_set = CustomDataset(...)
    train_loader = DataLoader(dataset=train_set, ...)
    dev_set = CustomDataset(...)
    dev_loader = DataLoader(dataset=dev_set, ...)
    

    您想将它们连接起来以使用 train+dev 作为训练数据,对吗?如果是这样,您只需调用:

    train_dev_sets = torch.utils.data.ConcatDataset([train_set, dev_set])
    train_dev_loader = DataLoader(dataset=train_dev_sets, ...)
    

    train_dev_loader 是包含两个集合数据的加载器。

    现在,请确保您的数据具有相同的形状和相同的类型,即相同数量的特征或相同的类别/编号等。

    【讨论】:

      【解决方案2】:

      我猜这两个数据集有时会返回不同的类型。当数据是张量时,火炬将它们堆叠起来,它们最好是相同的形状。如果它们像字符串,torch 会用它们组成一个元组。所以这听起来像你的一个数据集有时会返回不是张量的东西。我会在您的数据集的输出中添加一些断言,以检查它是否在执行您想要的操作,或者使用 pdb 进行深入研究。

      【讨论】:

        【解决方案3】:

        添加到@ Leopd的答案,您可以使用Pytorch提供的@ 987654323 function @ 987654321。这个想法是,在collate_fn中,您将定义例子应该如何堆叠以制作批处理。由于您在Torch 1.3.1上,请确保您正在查看documentation的正确版本。

        让我知道这是否有帮助或者是否有任何后续问题:)

        【讨论】:

          猜你喜欢
          • 2021-05-26
          • 2019-09-19
          • 2018-12-22
          • 2019-04-27
          • 1970-01-01
          • 2021-04-18
          • 2019-05-17
          • 2022-01-18
          • 2020-03-10
          相关资源
          最近更新 更多