【问题标题】:How can I save PyTorch's DataLoader instance?如何保存 PyTorch DataLoader 实例?
【发布时间】:2020-07-14 13:09:21
【问题描述】:

我想保存 PyTorch 的 torch.utils.data.dataloader.DataLoader 实例,这样我就可以从中断的地方继续训练(保留随机播放种子、状态和所有内容)。

【问题讨论】:

  • @Meh 谢谢,但 torch.save() 不保存状态。如果我保存并再次加载它,它将从一个新的随机种子开始。
  • 只需使用相同的 shuffle 种子就足以从上一个 epoch 重新开始训练。我不认为你可以在一个时代之间重新开始。

标签: pytorch


【解决方案1】:

您需要采样器的自定义实现。 可以通过以下方式轻松使用:https://gist.github.com/usamec/1b3b4dcbafad2d58faa71a9633eea6a5

您可以像这样保存和恢复:

sampler = ResumableRandomSampler(dataset)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, sampler=sampler, pin_memory=True)

for x in loader:
    print(x)
    break

sampler2 = ResumableRandomSampler(dataset)
torch.save(sampler.get_state(), "test_samp.pth")
sampler2.set_state(torch.load("test_samp.pth"))
loader2 = torch.utils.data.DataLoader(dataset, batch_size=2, sampler=sampler2, pin_memory=True)

for x in loader2:
    print(x)

【讨论】:

    【解决方案2】:

    这很简单。应该设计自己的Sampler,它采用起始索引并自行打乱数据:

    import random
    from torch.utils.data.dataloader import Sampler
    
    
    random.seed(224)  # use a fixed number
    
    
    class MySampler(Sampler):
        def __init__(self, data, i=0):
            random.shuffle(data)
            self.seq = list(range(len(data)))[i * batch_size:]
    
        def __iter__(self):
            return iter(self.seq)
    
        def __len__(self):
            return len(self.seq)
    

    现在将最后一个索引 i 保存在某处,下次使用它实例化 DataLoader

    train_dataset = MyDataset(train_data)
    train_sampler = MySampler(train_dataset, last_i)
    train_data_loader = DataLoader(dataset=train_dataset,                                                         
                                   batch_size=batch_size, 
                                   sampler=train_sampler,
                                   shuffle=False)  # don't forget to set DataLoader's shuffle to False
    

    在 Colab 上训练时非常有用。

    【讨论】:

    • 我相信你需要在洗牌后重新选择索引。否则你将不知道你已经覆盖了哪些索引。
    • @fsociety 不在此示例中。在这里,我使用一个恒定的随机种子,并且我不会在每个时期都对数据进行洗牌。所以,只知道最后一个索引就足够了。但是,如果要在训练时洗牌(或没有固定的随机种子),他们应该按照您的建议保留所有涵盖的索引。
    • 我看不出有什么不同。如果你想在一个 epoch 期间恢复你的采样器,这是​​这里的用例,你不知道你从哪里停下来。因此,假设您保存i,假设为 10。然后您只取所有大于该值的索引,将其随机排列,然后选择接下来的 10 个。但您不知道那些是哪 10 个。很容易修复,用固定种子改组后进行子选择。
    • @fsociety 如代码所示,shuffle 只进行一次,即在 Sampler 被实例化的时候。由于随机种子是相同的数字,因此每次对data 进行洗牌时,我们都会得到相同的索引顺序(这就是固定随机种子的含义)。在这种情况下,i 始终表示相同的索引。因此,总是在选择 ith 元素之前完成洗牌。
    • @fsociety 抱歉。我很久以前写了这个答案。我修复了代码。感谢您指出错误。
    【解决方案3】:

    对此的原生 PyTorch 支持仍然不可用,but considered for future improvements。不过,请参阅自定义构建的其他答案。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2021-07-16
      • 2022-01-24
      • 1970-01-01
      • 2019-05-03
      • 2021-04-18
      • 2019-07-20
      • 1970-01-01
      • 2021-02-15
      相关资源
      最近更新 更多