您可以通过继承 @Sai Krishnan 提到的内置 Dataset 类来制作自定义数据集加载器。
from collections import Counter
import matplotlib.pyplot as plt
import numpy as np
import os
import argparse
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from PIL import Image
# Pascal VOC category names; a name's position in this tuple is its label index.
VOC_CLASSES = (
    'background',  # always index 0
    'aeroplane',
    'bicycle',
    'bird',
    'boat',
    'bottle',
    'bus',
    'car',
    'cat',
    'chair',
    'cow',
    'diningtable',
    'dog',
    'horse',
    'motorbike',
    'person',
    'pottedplant',
    'sheep',
    'sofa',
    'train',
    'tvmonitor',
)

# One slot more than the named categories: the extra index, len(VOC_CLASSES),
# is the one the dataset code in this file assigns to raw mask value 255.
NUM_CLASSES = 1 + len(VOC_CLASSES)
class customDataset(Dataset):
    """Pascal VOC 2007 dataset that yields image/mask pairs as tensors.

    Image names are read from a text file (one name per line, without file
    extension). For each name the image is looked up as
    ``<img_dir>/<name>.jpg`` and its mask as ``<mask_dir>/<name>.png``.
    Both are resized to 224x224 when loaded.

    Args:
        list_file: Path to the text file listing the image names.
        img_dir: Directory containing the ``.jpg`` images.
        mask_dir: Directory containing the ``.png`` masks.
        transform: Optional callable applied to the whole sample dict
            (``{'image': ..., 'mask': ...}``); whatever it returns is what
            ``__getitem__`` returns. ``None`` leaves samples untouched.
    """

    # Target (width, height) passed to PIL's ``resize`` for images and masks.
    _RESIZE_TO = (224, 224)

    def __init__(self, list_file, img_dir, mask_dir, transform=None):
        # list of images to load in a .txt file
        self.images = self._read_image_list(list_file)
        self.transform = transform
        # note that in the .txt file the image names are stored without the
        # extension (.jpg or .png)
        self.img_extension = ".jpg"
        self.mask_extension = ".png"
        self.image_root_dir = img_dir
        self.mask_root_dir = mask_dir
        # Reads every mask once up front; this line can be commented out when
        # the class statistics (get_class_probability) are not needed.
        self.counts = self.__compute_class_probability()

    @staticmethod
    def _read_image_list(list_file):
        """Return the non-empty, whitespace-stripped lines of *list_file*.

        Blank lines are skipped, so the last name is kept whether or not the
        file ends with a newline.
        """
        with open(list_file, "rt") as handle:
            stripped = (line.strip() for line in handle)
            return [name for name in stripped if name]

    def __len__(self):
        """Return the number of image names in the list file."""
        return len(self.images)

    def __getitem__(self, index):
        """Load the sample at *index*.

        Returns:
            A dict with ``'image'`` (``torch.FloatTensor``) and ``'mask'``
            (``torch.LongTensor``), passed through ``self.transform`` when
            one was supplied.
        """
        name = self.images[index]
        image_path = os.path.join(self.image_root_dir, name + self.img_extension)
        mask_path = os.path.join(self.mask_root_dir, name + self.mask_extension)
        sample = {
            'image': torch.FloatTensor(self.load_image(path=image_path)),
            'mask': torch.LongTensor(self.load_mask(path=mask_path)),
        }
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def __compute_class_probability(self):
        """Count, over all masks, how many pixels carry each class index.

        Returns:
            A dict mapping every index in ``range(NUM_CLASSES)`` to its pixel
            count. Mask values outside that range are ignored.
        """
        totals = np.zeros(NUM_CLASSES, dtype=np.int64)
        for name in self.images:
            mask_path = os.path.join(self.mask_root_dir, name + self.mask_extension)
            labels = self.load_mask(path=mask_path).reshape(-1)
            # bincount tallies all labels in one pass; the slice drops any
            # value >= NUM_CLASSES so only known indices are counted.
            totals += np.bincount(labels, minlength=NUM_CLASSES)[:NUM_CLASSES]
        return {index: int(total) for index, total in enumerate(totals)}

    def get_class_probability(self):
        """Return the per-class pixel frequencies as a ``torch.Tensor``.

        Each entry is that class's pixel count divided by the total count
        stored in ``self.counts``.
        """
        values = np.array(list(self.counts.values()))
        p_values = values / np.sum(values)
        return torch.Tensor(p_values)

    def load_image(self, path=None):
        """Load the image at *path* as a float32 array scaled to [0, 1].

        The image is converted to RGB, resized to 224x224 and returned with
        shape ``(3, 224, 224)``.

        NOTE(review): the ``(2, 1, 0)`` transpose yields (channel, width,
        height), i.e. the two spatial axes are swapped relative to the
        (height, width) array from ``load_mask``. It is kept unchanged so
        existing callers see the same layout; confirm whether
        (channel, height, width) was intended.
        """
        # can use any other library too like OpenCV as long as you are
        # consistent with it
        with Image.open(path) as raw_image:
            # Force three channels so grayscale/RGBA files do not break the
            # 3-axis transpose below.
            resized = raw_image.convert("RGB").resize(self._RESIZE_TO)
        pixels = np.array(resized, dtype=np.float32) / 255.0
        return np.ascontiguousarray(np.transpose(pixels, (2, 1, 0)))

    def load_mask(self, path=None):
        """Load the mask at *path* as a 224x224 array of class indices.

        Raw value 255 is remapped to ``len(VOC_CLASSES)``, the extra index
        that follows the named classes.
        """
        with Image.open(path) as raw_mask:
            # Nearest-neighbour keeps every pixel a valid label id; an
            # interpolating filter would blend neighbouring ids.
            resized = raw_mask.resize(self._RESIZE_TO, resample=Image.NEAREST)
        labels = np.array(resized)
        labels[labels == 255] = len(VOC_CLASSES)
        return labels
一旦类准备就绪,您就可以创建它的实例并使用它。
# Example usage: build the dataset from a local VOC2007 checkout and show one
# image next to its segmentation mask.
data_root = os.path.join("VOCdevkit", "VOC2007")
list_file_path = os.path.join(data_root, "ImageSets", "Segmentation", "train.txt")
img_dir = os.path.join(data_root, "JPEGImages")
mask_dir = os.path.join(data_root, "SegmentationClass")
objects_dataset = customDataset(list_file=list_file_path,
                                img_dir=img_dir,
                                mask_dir=mask_dir)

# Index of the sample to display; any value in range(len(objects_dataset)) works.
k = 0
sample = objects_dataset[k]
image, mask = sample['image'], sample['mask']
# The dataset returns the image channel-first with its two spatial axes
# swapped; exchanging axes 0 and 2 gives the channel-last layout imshow needs.
image = image.transpose(0, 2)

fig = plt.figure()
image_axes = fig.add_subplot(1, 2, 1)
image_axes.imshow(image.numpy())
mask_axes = fig.add_subplot(1, 2, 2)
mask_axes.imshow(mask.numpy())
plt.show()
请确保正确填写各个文件路径。此外,您还必须在 customDataset 类中正确加载标签(掩码)。
注意:这段代码片段(snippet)只是一个示例,用来说明自定义数据加载器应该如何编写。您必须对其进行适当的修改,才能使其适用于您的具体情况。