【问题标题】:Pytorch customized dataloaderPytorch 自定义数据加载器
【发布时间】:2021-11-09 05:29:23
【问题描述】:

我正在尝试使用 pytorch-lightening 训练具有 MNIST 数据集的分类器。

import pytorch_lightning as pl
from torchvision import transforms
from torchvision.datasets import MNIST, SVHN
from torch.utils.data import DataLoader, random_split


class MNISTData(pl.LightningDataModule):

    def __init__(self, data_dir='./', batch_size=256):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.ToTensor()

    def download(self):
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            mnist_train = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)
        return mnist_train

    def val_dataloader(self):
        mnist_val = DataLoader(self.mnist_val, batch_size=self.batch_size)
        return mnist_val

    def test_dataloader(self):
        mnist_test = DataLoader(self.mnist_test, batch_size=self.batch_size)

使用MNISTData().setup()后,我得到了MNISTData().mnist_train, MNISTData().mnist_val, MNISTData().mnist_test,长度分别为55000、5000、10000,类型为torch.utils.data.dataset.Subset。

但是当我调用 dataloader w.r.t MNISTData().train_dataloader, MNISTData().val_dataloader, MNISTData().test_dataloader 时,我只得到包含 215、20、None 数据的 DataLoader。

有人可以知道原因或可以解决问题吗?

【问题讨论】:

  • 返回215, 20, None的代码在哪里?顺便说一句,test_dataloader(...) 中没有 return
  • 在更正 returntest_dataloader() 之后,我仍然有问题。
  • a = MNISTData() a.setup() b,c,d = a.train_dataloader(), a.val_dataloader(),a.test_dataloader() 你能试试上面的代码并检查变量吗?

标签: pytorch dataloader pytorch-lightning


【解决方案1】:

正如我在 cmets 中所说,以及 Ivan 在他的回答中发布的,缺少 return 声明:

def test_dataloader(self):
    mnist_test = DataLoader(self.mnist_test, batch_size=self.batch_size)
    return mnist_test  # <<< missing return

根据您的评论,如果我们尝试:

a = MNISTData()
# skip download, assuming you already have it
a.setup()

b, c, d = a.train_dataloader(), a.val_dataloader(), a.test_dataloader()
# len(b)=215, len(c)=20, len(d)=40

我认为您的问题是为什么 b, c, d 的长度与数据集的长度不同。答案是DataLoaderlen()等于批次数,而不是样本数,因此:

import math

batch_size = 256
len(b) = math.ceil(55000 / batch_size) = 215
len(c) = math.ceil(5000 / batch_size) = 20
len(d) = math.ceil(10000 / batch_size) = 40

顺便说一句,我们使用math.ceil,因为DataLoader 默认有drop_last=False,否则它将是math.floor

【讨论】:

    【解决方案2】:

    您的test_dataloader 函数缺少return 语句!

    def test_dataloader(self):
        mnist_test = DataLoader(self.mnist_test, batch_size=self.batch_size)
        return mnist_test
    

    >>> ds = MNISTData()
    >>> ds.download()
    >>> ds.setup()
    

    然后:

    >>> [len(subset) for subset in \
              (ds.mnist_train, ds.mnist_val, ds.mnist_test)]
    [55000, 5000, 10000]
    
    
    >>> [len(loader) for loader in \
             (ds.train_dataloader(), ds.val_dataloader(), ds.test_dataloader())]
    [215, 20, 40]
    

    【讨论】:

      【解决方案3】:

      其他人指出您缺少returntest_dataloader() 的事实当然是正确的。

      从问题的框架来看,您似乎对DatasetDataLoader 的长度感到困惑。

      len(Dataset(..)) 返回数据集中的数据样本数。

      len(DataLoader(ds, ...)) 返回批次数;这取决于您请求了多少batch_size=...,是否要drop_last 批处理等。确切的计算由@Berriel 正确提供

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 2021-12-29
        • 2019-11-08
        • 2019-12-30
        • 2017-09-12
        • 2020-12-04
        相关资源
        最近更新 更多