【发布时间】:2017-12-29 20:46:49
【问题描述】:
我注意到新的 Estimator API 会在训练期间自动保存检查点,并在训练中断时自动从上一个检查点重新开始。不幸的是,它似乎只保留最后 5 个检查点。
你知道如何控制训练期间保留的检查点数量吗?
【问题讨论】:
标签: tensorflow tensorflow-estimator
我注意到新的 Estimator API 会在训练期间自动保存检查点,并在训练中断时自动从上一个检查点重新开始。不幸的是,它似乎只保留最后 5 个检查点。
你知道如何控制训练期间保留的检查点数量吗?
【问题讨论】:
标签: tensorflow tensorflow-estimator
Tensorflow tf.estimator.Estimator 将config 作为可选参数,可以是tf.estimator.RunConfig 对象来配置运行时设置。您可以通过以下方式实现:
# Change maximum number checkpoints to 25
run_config = tf.estimator.RunConfig()
run_config = run_config.replace(keep_checkpoint_max=25)
# Build your estimator
estimator = tf.estimator.Estimator(model_fn,
model_dir=job_dir,
config=run_config,
params=None)
config 参数可用于扩展 estimator.Estimator 的所有类(DNNClassifier、DNNLinearCombinedClassifier、LinearClassifier 等)。
【讨论】:
save_checkpoints_secs 和 save_checkpoints_steps,完美!谢谢!
作为旁注,我想补充一点,在 TensorfFlow2 中,情况要简单一些。要保留一定数量的检查点文件,您可以修改model_main_tf2.py 源代码。首先,您可以添加并定义一个整数标志为
# Keep last 25 checkpoints
flags.DEFINE_integer('checkpoint_max_to_keep', 25,
'Integer defining how many checkpoint files to keep.')
然后在调用model_lib_v2.train_loop时使用这个预定义值:
# Ensure training loop keeps last 25 checkpoints
model_lib_v2.train_loop(...,
checkpoint_max_to_keep=FLAGS.checkpoint_max_to_keep,
...)
上面的符号... 表示model_lib_v2.train_loop 的其他选项。
【讨论】: