【问题标题】:How do you alter the size of a Pytorch Dataset? [duplicate]如何更改 Pytorch 数据集的大小?
【发布时间】:2017-12-05 00:25:32
【问题描述】:

假设我正在从 torchvision.datasets.MNIST 加载 MNIST,但我只想加载 10000 张图像,我将如何对数据进行切片以将其限制为仅一些数据点?我知道 DataLoader 是一个生成器,可以生成指定批量大小的数据,但是如何对数据集进行切片呢?

tr = datasets.MNIST('../data', train=True, download=True, transform=transform)
te = datasets.MNIST('../data', train=False, transform=transform)
train_loader = DataLoader(tr, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)
test_loader = DataLoader(te, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)

【问题讨论】:

    标签: python machine-learning dataset torch pytorch


    【解决方案1】:

    您可以使用torch.utils.data.Subset(),例如对于前 10,000 个元素:

    import torch.utils.data as data_utils
    
    indices = torch.arange(10000)
    tr_10k = data_utils.Subset(tr, indices)
    

    【讨论】:

    • 这个修改的是Dataset而不是DataLoader,很清晰。
    【解决方案2】:

    另一种切片数据集的快速方法是使用torch.utils.data.random_split()(PyTorch v0.4.1+ 支持)。它有助于将数据集随机拆分为给定长度的非重叠新数据集。

    所以我们可以有如下的东西:

    tr = datasets.MNIST('../data', train=True, download=True, transform=transform)
    te = datasets.MNIST('../data', train=False, transform=transform)
    
    part_tr = torch.utils.data.random_split(tr, [tr_split_len, len(tr)-tr_split_len])[0]
    part_te = torch.utils.data.random_split(te, [te_split_len, len(te)-te_split_len])[0]
    
    train_loader = DataLoader(part_tr, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)
    test_loader = DataLoader(part_te, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)
    

    您可以在此处将tr_split_lente_split_len 分别设置为训练和测试数据集所需的分割长度。

    【讨论】:

      【解决方案3】:

      请务必注意,当您创建 DataLoader 对象时,它不会立即加载您的所有数据(这对于大型数据集是不切实际的)。它为您提供了一个迭代器,您可以使用它来访问每个样本。

      很遗憾,DataLoader 没有为您提供任何方法来控制您希望提取的样本数量。您将不得不使用切片迭代器的典型方法。

      最简单的事情(没有任何库)是在达到所需的样本数量后停止。

      nsamples = 10000
      for i, image, label in enumerate(train_loader):
          if i > nsamples:
              break
      
          # Your training code here.
      

      或者,您可以使用 itertools.islice 获取前 10k 个样本。像这样。

      for image, label in itertools.islice(train_loader, stop=10000):
      
          # your training code here.
      

      【讨论】:

      • 此方法的警告:如果您在循环变量epoch 上多次迭代train_loader,您可能已经使用了所有样本进行训练...因为@987654329 DataLoader 中的 @ 选项将为每个 epoch 打乱样本。
      • 我不断收到类似DataLoader worker (pid(s) 9579) exited unexpectedly 的错误(在 OSX 上)
      猜你喜欢
      • 2023-03-22
      • 2021-07-25
      • 2019-09-10
      • 1970-01-01
      • 2021-11-13
      • 2019-01-21
      • 2019-07-29
      • 1970-01-01
      • 2019-07-12
      相关资源
      最近更新 更多