【问题标题】:Continue training on SavedModel or load checkpoint from SavedModel继续在 SavedModel 上训练或从 SavedModel 加载检查点
【发布时间】:2019-10-07 07:37:35
【问题描述】:

在 tensorflow 1.14 中,很明显 tf.compat.v1.train.init_from_checkpoint 可以加载 ckpt 以继续训练(或热启动)。但是,我在SavedModel 中找不到任何对应的方法,而tf.estimator.WarmStartSetting 也只支持ckpt。这对我来说很奇怪,因为this answer 提到应该有一个检查点存储在SavedModel 中。有谁知道:

  1. 如何在 SavedModel 中加载检查点?或
  2. 如何在 SavedModel 上热启动训练?

【问题讨论】:

  • 如果您有模型代码,最好使用检查点进行保存和恢复。``` tensorflow.org/guide/checkpoint``` 。它甚至与 tf1.0 兼容,可以加载使用saver.save()保存的旧检查点
  • 谢谢@SarathRNair。我知道检查站有效。但是,我想在需要 SavedModel 格式的 tf-serving 上部署我的模型。这就是为什么我想知道如何从 SavedModel 热启动,因为我不想保存和加载模型两次(检查点和 SavedModel)。
  • 你有解决办法吗?
  • 嗨@SarathRNair,不幸的是,我还没有找到解决方案。我也在 TF repo github.com/tensorflow/tensorflow/issues/33162 中问过这个问题,我希望他们会有这个新功能,或者有人可以有其他简单的解决方案。

标签: tensorflow tensorflow-serving tensorflow-estimator


【解决方案1】:

为了加载 SavedModel 继续训练,您可以使用 tf.saved_model.loader.load 如下:

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
  tf.saved_model.loader.load(sess, [tag_constants.SERVING], saved_model_location)

为了提供新的输入数据,您可以获得输入张量名称,如下所示:

signature_def = meta_graph_def.signature_def[tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
inputs = [v.name for v in signature_def.inputs.values()]
input_tensors = [node.split(":")[0] for node in inputs]

然后你可以制作一些feed_dict 来为输入张量提供新的输入。获取输出张量的方法与我上面概述的方法类似。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2017-10-29
    • 2019-12-18
    • 2020-05-22
    • 1970-01-01
    • 1970-01-01
    • 2020-12-13
    • 1970-01-01
    • 2021-01-07
    相关资源
    最近更新 更多