【问题标题】:Adding custom labels to pytorch dataloader/dataset does not work for custom dataset将自定义标签添加到 pytorch 数据加载器/数据集不适用于自定义数据集
【发布时间】:2019-11-08 12:06:16
【问题描述】:

我正在参加 Kaggle 上的仙人掌图像比赛,我正在尝试将 PyTorch 数据加载器用于我的 CNN。但是,我遇到了无法为训练集设置标签的问题。训练集图像在文件夹中,标签在 csv 文件中。这是我的代码。

 train = torchvision.datasets.ImageFolder(root='../input/train', 
 transform=transform)

 train.targets = torch.from_numpy(df['has_cactus'].values)

 train_loader = torch.utils.data.DataLoader(train, batch_size=64, shuffle=True, num_workers=2)

 for i, data in enumerate(train_loader, 0):
     print(data[1])

此代码输出全为零的批处理张量,这显然是不正确的,因为绝大多数标签(如果您要查看数据框)都是标签。我相信这是将标签分配给“train.targets”的问题。如果在分配其他标签之前打印“train.targets”,它会返回一个全为零的张量,这与我得到的不正确结果一致。我该如何解决这个问题?

【问题讨论】:

    标签: python machine-learning computer-vision pytorch


    【解决方案1】:

    我通常继承内置的DataSet类如下:

    from torch.utils.data import DataLoader
    class DataSet:
    
        def __init__(self, root):
            """Init function should not do any heavy lifting, but
                must initialize how many items are available in this data set.
            """
    
            self.ROOT = root
            self.images = read_images(root + "/images")
            self.labels = read_labels(root + "/labels")
    
        def __len__(self):
            """return number of points in our dataset"""
    
            return len(self.images)
    
        def __getitem__(self, idx):
            """ Here we have to return the item requested by `idx`
                The PyTorch DataLoader class will use this method to make an iterable for
                our training or validation loop.
            """
    
            img = images[idx]
            label = labels[idx]
    
            return img, label
    

    现在,你可以创建这个类的一个实例,

    ds = Dataset('../input/train')
    

    现在,您可以实例化 DataLoader:

    dl = DataLoader(ds, batch_size=TRAIN_BATCH_SIZE, shuffle=False, num_workers=4, drop_last=True)
    

    这将创建您可以访问的批次数据:

    for image, label in dl:
        print(label)
    

    【讨论】:

    • 感谢您的回复。您如何建议我在代码中实现“read_images”方法?
    • @LeoStepanewk 这是另一个问题吗?
    【解决方案2】:

    您可以通过继承 @Sai Krishnan 提到的内置 Dataset 类来制作自定义数据集加载器。

    from collections import Counter
    import matplotlib.pyplot as plt
    import numpy as np
    import os
    import argparse
    import torch
    from torch.utils.data import Dataset
    from tqdm import tqdm
    from PIL import Image
    
    VOC_CLASSES = ('background',  # always index 0
                   'aeroplane', 'bicycle', 'bird', 'boat',
                   'bottle', 'bus', 'car', 'cat', 'chair',
                   'cow', 'diningtable', 'dog', 'horse',
                   'motorbike', 'person', 'pottedplant',
                   'sheep', 'sofa', 'train', 'tvmonitor')
    
    NUM_CLASSES = len(VOC_CLASSES) + 1
    
    class customDataset(Dataset):
        """Pascal VOC 2007 Dataset"""
        def __init__(self, list_file, img_dir, mask_dir, transform=None):
            # list of images to load in a .txt file
            self.images = open(list_file, "rt").read().split("\n")[:-1]
            self.transform = transform
            # note that in the .txt file the image names are stored without the extension(.jpg or .png)
            self.img_extension = ".jpg"
            self.mask_extension = ".png"
    
            self.image_root_dir = img_dir
            self.mask_root_dir = mask_dir
            # can comment the line below
            self.counts = self.__compute_class_probability()
    
        def __len__(self):
            return len(self.images)
    
        def __getitem__(self, index):
            name = self.images[index]
            image_path = os.path.join(self.image_root_dir, name + self.img_extension)
            mask_path = os.path.join(self.mask_root_dir, name + self.mask_extension)
    
            image = self.load_image(path=image_path)
            gt_mask = self.load_mask(path=mask_path)
            data = {
                        'image': torch.FloatTensor(image),
                        'mask' : torch.LongTensor(gt_mask)
                        }
            return data
    
        def __compute_class_probability(self):
            counts = dict((i, 0) for i in range(NUM_CLASSES))
    
            for name in self.images:
                mask_path = os.path.join(self.mask_root_dir, name + self.mask_extension)
    
                raw_image = Image.open(mask_path).resize((224, 224))
                imx_t = np.array(raw_image).reshape(224*224)
                imx_t[imx_t==255] = len(VOC_CLASSES)
    
                for i in range(NUM_CLASSES):
                    counts[i] += np.sum(imx_t == i)
            return counts
    
        def get_class_probability(self):
            values = np.array(list(self.counts.values()))
            p_values = values/np.sum(values)
            return torch.Tensor(p_values)
    
        def load_image(self, path=None):
            # can use any other library too like OpenCV as long as you are consistent with it
            raw_image = Image.open(path)
            raw_image = np.transpose(raw_image.resize((224, 224)), (2,1,0))
            imx_t = np.array(raw_image, dtype=np.float32)/255.0
    
            return imx_t
        # can comment the below function if not needed
        def load_mask(self, path=None):
            raw_image = Image.open(path)
            raw_image = raw_image.resize((224, 224))
            imx_t = np.array(raw_image)
            imx_t[imx_t==255] = len(VOC_CLASSES)
            return imx_t
    
    

    一旦类准备就绪,您就可以创建它的实例并使用它。

    data_root = os.path.join("VOCdevkit", "VOC2007")
    list_file_path = os.path.join(data_root, "ImageSets", "Segmentation", "train.txt")
    img_dir = os.path.join(data_root, "JPEGImages")
    mask_dir = os.path.join(data_root, "SegmentationClass")
    
    
    objects_dataset = customDataset(list_file=list_file_path,
                                            img_dir=img_dir,
                                            mask_dir=mask_dir)
    sample = objects_dataset[k]
    image, mask = sample['image'], sample['mask']
    image.transpose_(0, 2)
    
    fig = plt.figure()
    
    a = fig.add_subplot(1,2,1)
    plt.imshow(image)
    
    a = fig.add_subplot(1,2,2)
    plt.imshow(mask)
    
    plt.show()
    
    

    请确保正确插入文件路径。此外,您还必须在 customDataset() 类中正确加载标签。

    注意:这个 sn-p 只是一个自定义数据加载器应该如何的示例。您必须对其进行适当的更改以使其适用于您的情况。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2021-11-09
      • 2019-12-30
      • 2014-10-14
      • 2017-09-12
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多