【发布时间】:2020-02-18 03:32:16
【问题描述】:
问题
在 PyTorch 中,我正在尝试编写一个类,它可以使用 dataset.data 和 dataset.label 等语法分别返回整个 data 和 label。代码骨架如下所示:
class MyDataset(object):
data = _get_data()
label = _get_label()
def __init__(self, dir, transforms):
self.img_list = ... # all image paths loaded from dir
# do something
def __getitem__(self):
# do something
return data, label
def __len__(self):
return len(self.img_list)
def _get_data():
# do something
def _get_label():
# do something
但是,当我使用dataset.data 和dataset.label 访问相应的变量时,没有返回任何内容。
我想知道为什么会出现这种情况以及如何解决这个问题。
编辑
感谢大家的关注。
我自己解决了这个问题。解决方案非常简单,它只是利用了类变量的属性。
class FaceDataset(object):
# class variable
data = None
label = None
def __init__(self, root, transforms=None):
# read img_list from root
img_list = ...
self.transforms = ...
FaceDataset.data = FaceDataset._get_data(self.img_list, self.transforms)
FaceDataset.label = FaceDataset._get_label(self.img_list)
@classmethod
def _get_data(cls, img_list, transforms):
data_list = []
for img_path in img_list:
data_list.append(transforms(Image.open(img_path)).unsqueeze(0))
return torch.stack(data_list, dim=0)
@classmethod
def _get_label(cls, img_list):
label = torch.zeros(len(img_list))
for i, img_path in enumerate(img_list):
label[i] = ...
return label
def __getitem__(self, index):
img_path = self.img_list[index]
label = ...
# read image from file
data = Image.open(img_path)
# apply transform defined in __init__
data = self.transforms(data)
return data, label
def __len__(self):
return len(self.img_list)
【问题讨论】:
-
我不知道为什么人们不赞成我的问题。
-
好,从你目前所做的开始!
-
这个问题无法回答,因为我们无法简单地猜测
_get_data()和_get_label()中的内容。此外,在 PyTorch 中,您应该始终为您的自定义数据集子类化 Dataset 类。 -
@Mat 它们返回图像的像素值和图像的相应标签(在我的例子中,图像中是否有人脸)。我不知道我可以直接继承
torch.utils.data.Dataset,目前我刚刚创建了一个可迭代对象(比如mydataset),然后使用dataloader = torch.utils.data.DataLoader(mydataset)之类的语法创建数据集。感谢您指出这一点。 -
这里有几件事:首先,如果您找到问题的答案,您应该发布该答案,而不是在问题中编辑它。其次,我投票决定关闭它,因为它不清楚;这不是定义自定义
Dataset的方式,因此猜测行为是不可行的。
标签: python class object pytorch