【发布时间】:2022-01-23 07:02:06
【问题描述】:
我的数据集由从原始图像(人脸补丁和随机人脸补丁之外)获得的图像补丁组成。补丁存储在一个文件夹中,该文件夹具有补丁源自的原始图像的名称。我创建了自己的 DataSet 和 DataLoader,但是当我遍历数据集时,数据不会批量返回。大小为 1 的批次应该包括一个补丁元组数组和一个标签,因此随着批次大小的增加,我们应该得到一个带有标签的元组数组。但是无论批量大小,DataLoader 都只返回一个元组数组。
我的数据集:
import os
import cv2 as cv
import PIL.Image as Image
import torchvision.transforms as Transforms
from torch.utils.data import dataset
class PatchDataset(dataset.Dataset):
def __init__(self, img_folder, n_patches):
self.img_folder = img_folder
self.n_patches = n_patches
self.img_names = sorted(os.listdir(img_folder))
self.transform = Transforms.Compose([
Transforms.Resize((50, 50)),
Transforms.ToTensor()
])
def __len__(self):
return len(self.img_names)
def __getitem__(self, idx):
img_name = self.img_names[idx]
patch_dir = os.path.join(self.img_folder, img_name)
patches = []
for i in range(self.n_patches):
face_patch = cv.imread(os.path.join(patch_dir, f'{str(i)}_face.png'))
face_patch = cv.cvtColor(face_patch, cv.COLOR_BGR2RGB)
face_patch = Image.fromarray(face_patch)
face_patch = self.transform(face_patch)
patch = cv.imread(os.path.join(patch_dir, f'{str(i)}_patch.png'))
patch = cv.cvtColor(patch, cv.COLOR_BGR2RGB)
patch = Image.fromarray(patch)
patch = self.transform(patch)
patches.append((face_patch, patch))
return patches, int(img_name.split('-')[0])
然后我就这样使用它:
X = PatchDataset(PATCHES_DIR, 9)
train_dl = dataloader.DataLoader(
X,
batch_size=10,
drop_last=True
)
for batch_X, batch_Y in train_dl:
print(len(batch_X))
print(len(batch_Y))
在此提供的情况下,批量大小为 10,因此打印 batch_Y 会返回正确的数字 (10)。但是batch_X 的打印返回 9,这是补丁对的数量 - 从数据集中只返回一个样本,而不是 10 个样本的批次,每个样本的长度为 9。
【问题讨论】:
-
通常你必须从
torch.utils.data.Dataset继承,但你使用dataset.Dataset我不知道。或者你做了:import torch.utils.data as dataset?如果这不是错误,请提供您使用数据加载器的代码:) -
@TheodorPeifer 是的,我是这样导入的,
DataLoader也是如此。我添加了您要求的示例并提供了更多信息。
标签: python pytorch dataset pytorch-dataloader