【问题标题】:Creating custom dataset in PyTorch在 PyTorch 中创建自定义数据集
【发布时间】:2020-02-18 03:32:16
【问题描述】:

问题

在 PyTorch 中,我正在尝试编写一个类,它可以使用 dataset.datadataset.label 等语法分别返回整个 datalabel。代码骨架如下所示:

class MyDataset(object):
  data = _get_data()
  label = _get_label()
  def __init__(self, dir, transforms):
    self.img_list = ... # all image paths loaded from dir
    # do something 

  def __getitem__(self):
    # do something
    return data, label

  def __len__(self):
    return len(self.img_list)

  def _get_data():
    # do something

  def _get_label():
    # do something

但是,当我使用dataset.datadataset.label 访问相应的变量时,没有返回任何内容。

我想知道为什么会出现这种情况以及如何解决这个问题。

编辑

感谢大家的关注。

我自己解决了这个问题。解决方案非常简单,它只是利用了类变量的属性。

class FaceDataset(object):
    # class variable
    data = None
    label = None

    def __init__(self, root, transforms=None):
        # read img_list from root
        img_list = ...
        self.transforms = ...
        FaceDataset.data = FaceDataset._get_data(self.img_list, self.transforms)
        FaceDataset.label = FaceDataset._get_label(self.img_list)

    @classmethod
    def _get_data(cls, img_list, transforms):
        data_list = []
        for img_path in img_list:
            data_list.append(transforms(Image.open(img_path)).unsqueeze(0))
        return torch.stack(data_list, dim=0)

    @classmethod
    def _get_label(cls, img_list):
        label = torch.zeros(len(img_list))
        for i, img_path in enumerate(img_list):
            label[i] = ...
        return label

    def __getitem__(self, index):
        img_path = self.img_list[index]
        label = ...

        # read image from file
        data = Image.open(img_path)
        # apply transform defined in __init__
        data = self.transforms(data)

        return data, label

    def __len__(self):
        return len(self.img_list)

【问题讨论】:

  • 我不知道为什么人们不赞成我的问题。
  • 好,从你目前所做的开始!
  • 这个问题无法回答,因为我们无法简单地猜测_get_data()_get_label() 中的内容。此外,在 PyTorch 中,您应该始终为您的自定义数据集子类化 Dataset 类。
  • @Mat 它们返回图像的像素值和图像的相应标签(在我的例子中,图像中是否有人脸)。我不知道我可以直接继承torch.utils.data.Dataset,目前我刚刚创建了一个可迭代对象(比如mydataset),然后使用dataloader = torch.utils.data.DataLoader(mydataset) 之类的语法创建数据集。感谢您指出这一点。
  • 这里有几件事:首先,如果您找到问题的答案,您应该发布该答案,而不是在问题中编辑它。其次,我投票决定关闭它,因为它不清楚;这不是定义自定义Dataset 的方式,因此猜测行为是不可行的。

标签: python class object pytorch


【解决方案1】:

在 Python 中创建自定义数据集的“正常”方法已经在 SO 上得到了回答 here。恰好有一个官方 PyTorch tutorial 用于此。

举个简单的例子,你可以阅读 PyTorch MNIST 数据集代码here(此数据集在此 PyTorch example code 中用于进一步说明)。最后,您可以在此 torchvision datasets list 中找到其他数据集实现(单击数据集名称,然后单击数据集文档中的“源”按钮,以访问数据集的 PyTorch 实现)。

【讨论】:

  • 感谢您的指点。由于阅读了 2 年前出版的一本 pytorch 书籍,我以“不正常”的方式这样做了,这本书显然已经过时了。
  • 对于 ML 库来说,2 年肯定已经过时了!为了避免在有据可查的 PyTorch 使用上浪费精力,我建议您查看 PyTorch 教程 page。每个教程都使用页面左侧的标题链接,此列表中的第二个教程是您对此问题感兴趣的教程。享受吧!
猜你喜欢
  • 1970-01-01
  • 2018-07-29
  • 1970-01-01
  • 2020-09-28
  • 2021-07-21
  • 2017-09-12
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多