【问题标题】:Stop and Restart Training on VGG-16在 VGG-16 上停止和重新开始训练
【发布时间】:2019-01-31 04:30:23
【问题描述】:

我正在使用预训练的 VGG-16 模型进行图像分类。我正在添加自定义最后一层,因为我的分类类的数量是 10。我正在训练模型 200 个时期。

我的问题是:如果我在某个时期随机停止(通过关闭 python 窗口)训练,有什么方法可以说没有。 50 和从那里恢复?我已阅读有关保存和重新加载模型的信息,但我的理解是这仅适用于我们的自定义模型,而不适用于 VGG-16 等预训练模型。

【问题讨论】:

    标签: python-3.x machine-learning keras checkpointing vgg-net


    【解决方案1】:

    这是ModelCheckpoint 的自定义版本,我用它来从给定的时期恢复训练,gist。它会将 epoch 和其他日志保存到相应的 JSON 文件中,它还会在开始时检查是否恢复训练。您需要调用get_last_epoch 并在model.fit 中设置initial_epoch 才能从那个时期恢复。

    import json
    
    class StatefulCheckpoint(ModelCheckpoint):
      """Save extra checkpoint data to resume training."""
      def __init__(self, weight_file, state_file=None, **kwargs):
        """Save the state (epoch etc.) along side weights."""
        super().__init__(weight_file, **kwargs)
        self.state_f = state_file
        self.state = dict()
        if self.state_f:
          # Load the last state if any
          try:
            with open(self.state_f, 'r') as f:
              self.state = json.load(f)
            self.best = self.state['best']
          except Exception as e: # pylint: disable=broad-except
            print("Skipping last state:", e)
    
      def on_train_begin(self, logs=None):
        prefix = "Resuming" if self.state else "Starting"
        print("{} training...".format(prefix))
    
      def on_epoch_end(self, epoch, logs=None):
        """Saves training state as well as weights."""
        super().on_epoch_end(epoch, logs)
        if self.state_f:
          state = {'epoch': epoch+1, 'best': self.best}
          state.update(logs)
          state.update(self.params)
          with open(self.state_f, 'w') as f:
            json.dump(state, f)
    
      def get_last_epoch(self, initial_epoch=0):
        """Return last saved epoch if any, or return default argument."""
        return self.state.get('epoch', initial_epoch)
    

    【讨论】:

      【解决方案2】:

      您可以使用ModelCheckpoint 回调定期保存您的模型。要使用它,请将callbacks 参数传递给fit 方法:

      from keras.callbacks import ModelCheckpoint
      checkpointer = ModelCheckpoint(filepath='model-{epoch:02d}.hdf5', ...)
      model.fit(..., callbacks=[checkpointer])
      

      然后,稍后您可以加载最后保存的模型。有关此回调的更多自定义,请查看文档。

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 2017-09-28
        • 2019-08-01
        • 2018-08-03
        • 1970-01-01
        • 2021-01-31
        • 1970-01-01
        • 2020-02-23
        • 2016-11-16
        相关资源
        最近更新 更多