【发布时间】:2021-11-09 05:29:23
【问题描述】:
我正在尝试使用 pytorch-lightening 训练具有 MNIST 数据集的分类器。
import pytorch_lightning as pl
from torchvision import transforms
from torchvision.datasets import MNIST, SVHN
from torch.utils.data import DataLoader, random_split
class MNISTData(pl.LightningDataModule):
def __init__(self, data_dir='./', batch_size=256):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
self.transform = transforms.ToTensor()
def download(self):
MNIST(self.data_dir, train=True, download=True)
MNIST(self.data_dir, train=False, download=True)
def setup(self, stage=None):
if stage == 'fit' or stage is None:
mnist_train = MNIST(self.data_dir, train=True, transform=self.transform)
self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
if stage == 'test' or stage is None:
self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
def train_dataloader(self):
mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)
return mnist_train
def val_dataloader(self):
mnist_val = DataLoader(self.mnist_val, batch_size=self.batch_size)
return mnist_val
def test_dataloader(self):
mnist_test = DataLoader(self.mnist_test, batch_size=self.batch_size)
使用MNISTData().setup()后,我得到了MNISTData().mnist_train, MNISTData().mnist_val, MNISTData().mnist_test,长度分别为55000、5000、10000,类型为torch.utils.data.dataset.Subset。
但是当我调用 dataloader w.r.t MNISTData().train_dataloader, MNISTData().val_dataloader, MNISTData().test_dataloader 时,我只得到包含 215、20、None 数据的 DataLoader。
有人可以知道原因或可以解决问题吗?
【问题讨论】:
-
返回
215, 20, None的代码在哪里?顺便说一句,test_dataloader(...)中没有return。 -
在更正
return的test_dataloader()之后,我仍然有问题。 -
a = MNISTData()a.setup()b,c,d = a.train_dataloader(), a.val_dataloader(),a.test_dataloader()你能试试上面的代码并检查变量吗?
标签: pytorch dataloader pytorch-lightning