【问题标题】:What needs to be saved to reuse a model in TensorFlow在 TensorFlow 中重用模型需要保存什么
【发布时间】:2017-02-22 11:02:45
【问题描述】:

我已经开始探索 TensorFlow 库并尝试使用 MNIST 数据的图像分类example。我希望在训练阶段结束后将模型存储在一个文件中,以便我可以在需要时使用它。我检查了this link,它讲述了如何将值从 TensorFlow 保存到任何文件并读取它。到目前为止,我可以按照链接中的建议使用 pickle 将脚本中的一些变量保存到文件中。但是,我无法掌握需要保存在文件中以存储模型的当前状态以供以后使用的内容。 请有人可以通过存储模型和加载该模型的示例来解释该部分。

【问题讨论】:

  • 我在 git 上分享了一个示例 python 脚本,它训练算法并将训练的数据保存在指定位置,另一个脚本加载相同的数据并根据训练的数据进行评估。你可以在这里找到相同的:github.com/asonipsl/utility-projects/tree/master/…

标签: machine-learning tensorflow


【解决方案1】:

要在 Tensorflow 中保存和恢复变量,需要以下内容。

1) 要保存和恢复的变量列表 2) tf.train.Saver

一般来说,1)是通过

实现的
# To save and restore whole tf variables
all_vars = tf.global_variables()

或者,

# To save and restore the specific tf variables using scope
all_vars = tf.global_variables()
model_vars = [k for k in all_vars if k.name.startswith("xxx")]
# "xxx" is the expected scope

那么,2) 是由

实现的
saver = tf.train.Saver(vars_list)
# vars_list is list of variables from above

最后,保存变量,(使用名为 'sess' 的 tf.Session() 运行)

saver.save(sess, '/directory/to/chechpoint/file.ckpt')

并恢复它们,

saver.restore(sess, '/directory/to/chechpoint/file.ckpt')

【讨论】:

    【解决方案2】:

    只有Variables 可以保存和恢复。当您需要重用保存的变量时,您需要首先通过创建神经网络并设置 NN 的参数(如层数、学习率和 dropout 等)来构建图形。从检查点恢复的唯一值是训练中定义的变量过程。您可以查看任何示例,例如this one

    综上所述,只有Variables可以并且需要保存和恢复,神经网络配置,placeholders不能。

    【讨论】:

      【解决方案3】:

      首先,你应该看看这个other question

      TensorFlow 实现了用于管理保存和恢复检查点的方法,特别是tf.train.saver 类。查看官方文档here。检查点基本上将张量的值(以及其他内容)存储在磁盘中。

      引用文档:

      检查点是专有格式的二进制文件,将变量名称映射到张量值。检查检查点内容的最佳方法是使用Saver 加载它。

      【讨论】:

        猜你喜欢
        • 2019-05-20
        • 1970-01-01
        • 2017-01-15
        • 1970-01-01
        • 2018-11-26
        • 2019-08-30
        • 1970-01-01
        • 2017-05-07
        • 1970-01-01
        相关资源
        最近更新 更多