【问题标题】:Taking subsets of a pytorch dataset获取 pytorch 数据集的子集
【发布时间】:2018-05-06 01:01:05
【问题描述】:

我有一个网络,我想在一些数据集上进行训练(例如,CIFAR10)。我可以通过

创建数据加载器对象
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

我的问题如下:假设我想做几个不同的训练迭代。假设我想首先在奇数位置的所有图像上训练网络,然后在偶数位置的所有图像上训练网络,依此类推。为此,我需要能够访问这些图像。不幸的是,trainset 似乎不允许此类访问。也就是说,尝试执行trainset[:1000] 或更一般的trainset[mask] 会引发错误。

我可以这样做

trainset.train_data=trainset.train_data[mask]
trainset.train_labels=trainset.train_labels[mask]

然后

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                              shuffle=True, num_workers=2)

但是,这将迫使我在每次迭代中创建完整数据集的新副本(因为我已经更改了 trainset.train_data,所以我需要重新定义 trainset)。有什么办法可以避免吗?

理想情况下,我想要一些“等价于”的东西

trainloader = torch.utils.data.DataLoader(trainset[mask], batch_size=4,
                                              shuffle=True, num_workers=2)

【问题讨论】:

    标签: python machine-learning neural-network torch pytorch


    【解决方案1】:

    torch.utils.data.Subset 更简单,支持shuffle,并且不需要自己编写采样器:

    import torchvision
    import torch
    
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=None)
    
    evens = list(range(0, len(trainset), 2))
    odds = list(range(1, len(trainset), 2))
    trainset_1 = torch.utils.data.Subset(trainset, evens)
    trainset_2 = torch.utils.data.Subset(trainset, odds)
    
    trainloader_1 = torch.utils.data.DataLoader(trainset_1, batch_size=4,
                                                shuffle=True, num_workers=2)
    trainloader_2 = torch.utils.data.DataLoader(trainset_2, batch_size=4,
                                                shuffle=True, num_workers=2)
    

    【讨论】:

    • 不需要将evensodds 转换为列表——至少在torch 1.5.0 中,Subset 接受生成器:ts1 = Subset(trainset, range(0, len(trainset), 2))
    • 它不允许按类过滤,只能按数据集原始顺序过滤,是吗?
    • @user650654 有点离题,但range 不是生成器。
    • 索引集必须是 python Sequence。即listtuplerange
    【解决方案2】:

    您可以为数据集加载器定义一个自定义采样器,避免重新创建数据集(只需为每个不同的采样创建一个新的加载器)。

    class YourSampler(Sampler):
        def __init__(self, mask):
            self.mask = mask
    
        def __iter__(self):
            return (self.indices[i] for i in torch.nonzero(self.mask))
    
        def __len__(self):
            return len(self.mask)
    
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    
    sampler1 = YourSampler(your_mask)
    sampler2 = YourSampler(your_other_mask)
    trainloader_sampler1 = torch.utils.data.DataLoader(trainset, batch_size=4,
                                              sampler = sampler1, shuffle=False, num_workers=2)
    trainloader_sampler2 = torch.utils.data.DataLoader(trainset, batch_size=4,
                                              sampler = sampler2, shuffle=False, num_workers=2)
    

    PS:你可以在这里找到更多信息:http://pytorch.org/docs/master/_modules/torch/utils/data/sampler.html#Sampler

    【讨论】:

    • 谢谢!一个小评论:显然采样器与随机播放不兼容,因此为了达到相同的结果,可以这样做:torch.utils.data.DataLoader(trainset, batch_size=4, sampler=SubsetRandomSampler(np.where(mask)[0 ]),shuffle=False, num_workers=2)
    • 请记住,索引的listsampler 的有效参数,因为它实现了__len____iter__。这种规避了对自定义采样器类的需求。
    猜你喜欢
    • 2021-10-23
    • 2023-02-15
    • 2011-11-20
    • 1970-01-01
    • 2020-06-21
    • 2020-09-26
    • 2021-08-06
    • 2018-03-09
    • 2022-07-18
    相关资源
    最近更新 更多