【发布时间】:2021-01-26 22:33:06
【问题描述】:
我尝试创建自定义数据集,但在显示某些图像时出现错误。这是我的数据集类和转换:
transforms = transforms.Compose([transforms.Resize(224,224)])
class MyDataset(Dataset):
def __init__(self, path, label, transform=None):
self.path = glob.glob(os.path.join(path, '*.jpg'))
self.transform = transform
self.label = label
def __getitem__(self, index):
img = io.imread(self.path[index])
img = torch.tensor(img)
labels = torch.tensor(int(self.label))
if self.transform:
img = self.transform(img)
return (img,labels)
def __len__(self):
return len(self.path)
这里是错误行:
images, labels = next(iter(train_loader))
【问题讨论】:
标签: pytorch