【发布时间】:2020-06-11 08:05:19
【问题描述】:
我有这样的数据集类:
class LoadDataset(Dataset):
def __init__(self, data, label):
self.data = data
self.label = label
def __len__(self):
dlen = len(self.data)
return dlen
def __getitem__(self, index):
return self.data, self.label
然后我加载具有 [485, 1, 32, 32] 形状的图像数据集
train_dataset = LoadDataset(xtrain, ytrain)
print(len(train_dataset))
# output 485
然后我用DataLoader加载数据
train_loader = DataLoader(train_dataset, batch_size=32)
然后我迭代数据:
for epoch in range(num_epoch):
for inputs, labels in train_loader:
print(inputs.shape)
输出打印torch.Size([32, 485, 1, 32, 32]),应该是torch.Size([32, 1, 32, 32]),
谁能帮帮我?
【问题讨论】:
标签: python machine-learning computer-vision pytorch