【问题标题】:Pytorch Dataloader not spliting data into batchPytorch Dataloader 没有将数据拆分成批处理
【发布时间】:2020-06-11 08:05:19
【问题描述】:

我有这样的数据集类:

class LoadDataset(Dataset):
    def __init__(self, data, label):
        self.data = data
        self.label = label
    def __len__(self):
        dlen = len(self.data)
        return dlen
    def __getitem__(self, index):
        return self.data, self.label

然后我加载具有 [485, 1, 32, 32] 形状的图像数据集

train_dataset = LoadDataset(xtrain, ytrain)
print(len(train_dataset))
# output 485

然后我用DataLoader加载数据

train_loader = DataLoader(train_dataset, batch_size=32)

然后我迭代数据:

for epoch in range(num_epoch):
        for inputs, labels in train_loader:   
            print(inputs.shape)

输出打印torch.Size([32, 485, 1, 32, 32]),应该是torch.Size([32, 1, 32, 32])

谁能帮帮我?

【问题讨论】:

    标签: python machine-learning computer-vision pytorch


    【解决方案1】:

    __getitem__ 方法应该返回 1 个数据片段,你返回了所有数据片段。

    试试这个:

    class LoadDataset(Dataset):
        def __init__(self, data, label):
            self.data = data
            self.label = label
        def __len__(self):
            dlen = len(self.data)
            llen = len(self.label)  # different here
            return min(dlen, llen)  # different here
        def __getitem__(self, index):
            return self.data[index], self.label[index]  # different here
    

    【讨论】:

      猜你喜欢
      • 2022-01-23
      • 1970-01-01
      • 2019-07-29
      • 2019-04-27
      • 2021-10-23
      • 2021-02-11
      • 2020-10-09
      • 2019-10-07
      • 1970-01-01
      相关资源
      最近更新 更多