【问题标题】:How to get a specific sample from pytorch DataLoader?如何从 pytorch DataLoader 获取特定样本?
【发布时间】:2020-07-07 00:20:49
【问题描述】:

在 Pytorch 中,有没有办法使用 torch.utils.data.DataLoader 类加载特定单个样本?我想用它做一些测试。

tutorial 使用

trainloader = torch.utils.data.DataLoader(...)
images, labels = next(iter(trainloader))

获取 随机 批样本。有没有办法,使用DataLoader,得到一个特定的样本?

干杯

【问题讨论】:

标签: pytorch


【解决方案1】:
  • 关闭shuffle 中的DataLoader
  • 使用batch_size 计算您要查找的所需样品所属的批次
  • 迭代到所需的批次

代码

import torch 
import numpy as np
import itertools

X= np.arange(100)
batch_size = 2

dataloader = torch.utils.data.DataLoader(X, batch_size=batch_size, shuffle=False)
sample_at = 5
k = int(np.floor(sample_at/batch_size))

my_sample = next(itertools.islice(dataloader, k, None))
print (my_sample)

输出:

tensor([4, 5])

【讨论】:

  • 感谢您的回答@mujjiga,就像一个魅力!
  • 很好的答案,正是需要的。
【解决方案2】:

如果您想从数据集中获取特定的信号样本,您可以
你应该检查子集类。(https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset) 像这样:

indices =  [0,1,2]  # select your indices here as a list  
subset = torch.utils.data.Subset(train_set, indices)
trainloader = DataLoader(subset , batch_size =  16  , shuffle =False) #set shuffle to False 

for image , label in trainloader:
   print(image.size() , '\t' , label.size())
   print(image[0], '\t' , label[0]) # index the specific sample 

如果您想了解有关 Pytorch 数据加载实用程序的更多信息,这里是一个有用的链接 (https://pytorch.org/docs/stable/data.html)

【讨论】:

    猜你喜欢
    • 2021-11-26
    • 2019-11-04
    • 1970-01-01
    • 2019-05-03
    • 2021-02-15
    • 2018-10-11
    • 2021-07-09
    • 2020-07-14
    • 2022-01-24
    相关资源
    最近更新 更多