【问题标题】:How to control amount of checkpoint kept by tensorflow estimator?如何控制张量流估计器保留的检查点数量?
【发布时间】:2017-12-29 20:46:49
【问题描述】:

我注意到新的 Estimator API 会在训练期间自动保存检查点,并在训练中断时自动从上一个检查点重新开始。不幸的是,它似乎只保留最后 5 个检查点。

你知道如何控制训练期间保留的检查点数量吗?

【问题讨论】:

    标签: tensorflow tensorflow-estimator


    【解决方案1】:

    Tensorflow tf.estimator.Estimatorconfig 作为可选参数,可以是tf.estimator.RunConfig 对象来配置运行时设置。您可以通过以下方式实现:

    # Change maximum number checkpoints to 25
    run_config = tf.estimator.RunConfig()
    run_config = run_config.replace(keep_checkpoint_max=25)
    
    # Build your estimator
    estimator = tf.estimator.Estimator(model_fn,
                                       model_dir=job_dir,
                                       config=run_config,
                                       params=None)
    

    config 参数可用于扩展 estimator.Estimator 的所有类(DNNClassifierDNNLinearCombinedClassifierLinearClassifier 等)。

    【讨论】:

    • 正是我需要的信息,并且 RunConfig 有额外的参数,如 save_checkpoints_secssave_checkpoints_steps,完美!谢谢!
    【解决方案2】:

    作为旁注,我想补充一点,在 TensorfFlow2 中,情况要简单一些。要保留一定数量的检查点文件,您可以修改model_main_tf2.py 源代码。首先,您可以添加并定义一个整数标志为

    # Keep last 25 checkpoints
    flags.DEFINE_integer('checkpoint_max_to_keep', 25,
                         'Integer defining how many checkpoint files to keep.')
    

    然后在调用model_lib_v2.train_loop时使用这个预定义值:

    # Ensure training loop keeps last 25 checkpoints
    model_lib_v2.train_loop(...,
                            checkpoint_max_to_keep=FLAGS.checkpoint_max_to_keep,
                            ...)
    

    上面的符号... 表示model_lib_v2.train_loop 的其他选项。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2018-04-09
      • 1970-01-01
      • 2019-08-02
      • 2019-02-04
      • 2018-06-16
      • 1970-01-01
      • 2018-11-07
      • 2018-08-30
      相关资源
      最近更新 更多