【问题标题】:Loading custom dataset in pytorch在 pytorch 中加载自定义数据集
【发布时间】:2019-05-23 12:38:28
【问题描述】:

通常,当我们在pytorch中加载数据时,我们会执行以下操作

for x, y in dataloaders:
    # Do something

但是,在这个名为 MusicNet 的数据集中,他们像这样声明自己的数据集和数据加载器

train_set = musicnet.MusicNet(root=root, train=True, download=True, window=window)#, pitch_shift=5, jitter=.1)
test_set = musicnet.MusicNet(root=root, train=False, window=window, epoch_size=50000)

train_loader = torch.utils.data.DataLoader(dataset=train_set,batch_size=batch_size,**kwargs)
test_loader = torch.utils.data.DataLoader(dataset=test_set,batch_size=batch_size,**kwargs)

然后他们像这样加载数据

with train_set, test_set:
    for i, (x, y) in enumerate(train_loader):
        # Do something

问题 1

我不明白为什么没有with train_set, test_set 这一行代码就不能工作。

问题 2

另外,我如何访问数据?

我试过了

train_set.access(2560,0)

with train_set, test_set:
    x, y = train_set.access(2560,0)

他们要么给我一条错误消息,例如

KeyError Traceback(最近一次调用最后一次)在 ----> 1 train_set.access(2560,0)

/workspace/raven_data/AMT/MusicNet/pytorch_musicnet/musicnet.py 访问(self,rec_id,s,shift,jitter)106 107 if self.mmap: --> 108 x = np.frombuffer(self.records[rec_id][0][ssz_float:int(s+scaleself.window)*sz_float], dtype=np.float32).copy() 109 else: 110 fid,_ = self.records[rec_id]

密钥错误:2560

或者给我一个空的xy

【问题讨论】:

    标签: python python-3.x pytorch


    【解决方案1】:

    问题一

    我不明白为什么没有with train_set, test_set 这一行代码就不能工作。

    为了能够将torch.utils.data.DataLoader自定义 数据集设计一起使用,您必须创建一个数据集类,该类是torch.utils.data.Dataset 的子类 (并实现特定功能) 并将其传递给数据加载器,即使他们这么说:

    所有其他数据集都应该对其进行子类化。所有子类都应该覆盖__len__,它提供了数据集的大小,和__getitem__,支持从0到len(self)独占的整数索引。

    这是发生在:

    train_set = musicnet.MusicNet(root=root, train=True, download=True, window=window)#, pitch_shift=5, jitter=.1)
    
    test_set = musicnet.MusicNet(root=root, train=False, window=window, epoch_size=50000)
    
    train_loader = torch.utils.data.DataLoader(dataset=train_set,batch_size=batch_size,**kwargs)
    test_loader = torch.utils.data.DataLoader(dataset=test_set,batch_size=batch_size,**k
    

    如果你查看他们的musicnet.MusicNet,你会发现他们这样做了。

    问题2

    另外,我如何访问数据?

    有可能的方法:

    从数据集中获取一个批次,您可以这样做:

    batch = next(iter(train_loader))
    

    访问整个数据集(尤其是在您的示例中)

    dataset = train_loader.dataset.records
    

    .records 是可能因数据集而异的部分,我说.records 因为这是我在here 中找到的)

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2017-09-12
      • 2019-12-30
      • 2021-11-09
      • 1970-01-01
      • 2019-11-08
      • 2020-02-18
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多