【问题标题】:TensorFlow Eager Mode: How to restore a model from a checkpoint?TensorFlow Eager Mode:如何从检查点恢复模型?
【发布时间】:2017-12-17 05:39:21
【问题描述】:

我已经在 TensorFlow Eager 模式下训练了一个 CNN 模型。现在我正在尝试从检查点文件中恢复经过训练的模型,但没有成功。

我发现的所有示例(如下所示)都在谈论将检查点恢复到会话。但我需要将模型恢复为渴望模式,即不创建会话。

with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")

基本上我需要的是这样的:

tfe.enable_eager_execution()
model = tfe.restore('model.ckpt')
model.predict(...)

然后我可以使用模型进行预测。

有人可以帮忙吗?

更新

示例代码见:mnist eager mode demo

我尝试按照@Jay Shah 的回答中的步骤进行操作,几乎可以正常工作,但恢复后的模型中没有任何变量。

tfe.save_network_checkpoint(model,'./test/my_model.ckpt')

Out[58]:
'./test/my_model.ckpt-1720'

model2 = MNISTModel()
tfe.restore_network_checkpoint(model2,'./test/my_model.ckpt-1720')
model2.variables

Out[72]:
[]

原始模型中有很多变量。:

model.variables

[<tf.Variable 'mnist_model_1/conv2d/kernel:0' shape=(5, 5, 1, 32) dtype=float32, numpy=
 array([[[[ -8.25184360e-02,   6.77833706e-03,   6.97569922e-02,...

【问题讨论】:

  • 为什么检查点名称不同???保存检查点路径与您正在恢复的路径不同...输出对我来说似乎很奇怪
  • 如果我使用相同的检查点名称,它将不起作用。 'my_model.ckpt-1720' 是 save_network_checkpoint 函数返回的名称。根据文档,这应该是用于恢复模型的名称。
  • 哦..是的,返回的值必须去...只是确保您以正确的方式进行操作
  • 嘿@Allen 你可以试试这个来保存模型tf.contrib.eager.Saver([variable_list]).save(chkpt_file),它会返回一个字符串,所以在恢复时使用该字符串如下:tf.contrib.eager.Saver.restore(returned_string) 这个东西模仿tf.train.Saver,但没有会话为此需要...但是必须启用急切模式...在制作 saver 对象时,您必须提供要存储的变量列表...。您可以从我的答案中获取变量列表.. . 变量必须是tfe.Variable

标签: python tensorflow deep-learning


【解决方案1】:

Eager Execution 仍然是 TensorFlow 中的一项新功能,并且未包含在最新版本中,因此并非所有功能都受支持,但幸运的是,从保存的检查点加载模型是。

您需要使用 tfe.Saver 类(它是 tf.train.Saver 类的薄包装),并且您的代码应如下所示:

saver = tfe.Saver([x, y])
saver.restore('/tmp/ckpt')

其中 [x,y] 表示您希望恢复的变量和/或模型的列表。这应该与最初创建创建检查点的保护程序时传递的变量精确匹配。

更多细节,包括示例代码,可以找到here,saver的API细节可以找到here

【讨论】:

  • 特别是tfe.restore_variables_on_create,如果您的变量在您想要恢复时尚未创建,则非常有用。这也用于eager mnist example
  • 感谢@mr_snuffles 的回答。您展示了如何恢复变量。您能否解释一下如何在 mnist Eager 模式教程github.com/tensorflow/tensorflow/blob/… 中恢复模型?
  • @Allen 您使用 saver.save 保存会话,然后调用该保护程序从最新的检查点恢复模型。它应该看起来像:sess = tf.Session() new_saver = tf.train.import_meta_graph('my-model.meta') new_saver.restore(sess, tf.train.latest_checkpoint('./'))更多关于保存和加载模型的细节可以找到here
  • 因此,在 2018 年 9 月,Eager 模式似乎仍然无法加载图表和检查点。会话图编码仍然是必需的,这使得研究人员使用了两种不同的 TF API。伙计们,这并不比以前简单。 TF比以前差了。
【解决方案2】:

好的,在以逐行模式运行代码几个小时后,我找到了一种将检查点恢复到新的 TensorFlow Eager Mode 模型的方法。

使用来自TF Eager Mode MNIST的示例

步骤:

  1. 你的模型训练完成后,从训练过程中创建的checkpoint文件夹中找到最新的checkpoint(或者你想要的checkpoint)索引文件,比如'ckpt-25800.index'。在步骤 5 中恢复时仅使用文件名“ckpt-25800”。

  2. 启动一个新的python终端并通过运行启用TensorFlow Eager模式:

    tfe.enable_eager_execution()

  3. 创建一个新的 MNISTMOdel 实例:

    model_new = MNISTModel()

  4. 通过运行一次虚拟训练过程来初始化model_new的变量。(此步骤很重要。如果不先初始化变量,则无法通过以下步骤恢复它们。但是我找不到另一种初始化变量的方法在 Eager 模式下,而不是我在下面所做的。)

    model_new(tfe.Variable(np.zeros((1,784),dtype=np.float32)), training=True)

  5. 使用步骤 1 中确定的检查点将变量恢复到 model_new。

    tfe.Saver((model_new.variables)).restore('./tf_checkpoints/ckpt-25800')

  6. 如果恢复过程成功,您应该会看到如下内容:

    INFO:tensorflow:Restoring parameters from ./tf_checkpoints/ckpt-25800

现在检查点已成功恢复到model_new,您可以使用它对新数据进行预测。

【讨论】:

  • 这行得通,但运行一个假前传球至少可以这么说看起来很糟糕:\不是你的错,而是 TF 的
  • 我们如何使用 TFE API 从磁盘完全加载模型,并填充其所有节点、图形、变量、超参数等?不要使用 MNISTModel() 代码重复模型创建代码。不要使用图形 API 和会话。仅使用 TFE API。 TF 人员仍然没有提供这种基本需求的工作示例。 Smartalecs 到处都只是提供文档的链接,就好像这值得一分钱一样。
【解决方案3】:

我喜欢分享TFLearn 库,即Deep learning library featuring a higher-level API for TensorFlow。借助这个库,您可以轻松地save and restore 一个模型。

保存模型

model = tflearn.DNN(net) #Here 'net' is your designed network model. 
#This is a sample example for training the model
model.fit(train_x, train_y, n_epoch=10, validation_set=(test_x, test_y), batch_size=10, show_metric=True)
model.save("model_name.ckpt")

恢复模型

model = tflearn.DNN(net)
model.load("model_name.ckpt")

有关tflearn 的更多示例,您可以查看一些网站,例如...

【讨论】:

  • 感谢您的回答,很高兴知道它可以使用 TFlearn 完成。但是,我仍然想在 TensorFlow 中找到一种方法。
【解决方案4】:
  • 首先,通过执行以下操作将模型保存在检查点中:

saver.save(sess, './my_model.ckpt')

  • 在上面一行中,您将会话保存在“my_model.ckpt”检查点中

以下代码恢复模型

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, './my_model.ckpt')
  • 当您将会话恢复为模型时,您将从 ckpt 恢复您的模型

在急切模式下保存:

tf.contrib.eager.save_network_checkpoint(sess,'./my_model.ckpt')

要恢复急切模式:

tf.contrib.eager.restore_network_checkpoint(sess,'./my_model.ckpt')

sess 是 Network 类的对象。网络类的任何对象都可以保存和恢复。网络对象的快速解释:-

class TwoLayerNetwork(tfe.Network):
    def __init__(self, name):
        super(TwoLayerNetwork, self).__init__(name=name)
        self.layer_one = self.track_layer(tf.layers.Dense(16, input_shape=(8,)))
        self.layer_two = self.track_layer(tf.layers.Dense(1, input_shape=(16,)))
    def call(self, inputs):
        return self.layer_two(self.layer_one(inputs))

在构造对象并调用Network后,变量列表 由跟踪的Layers 创建可通过Network.variables 获得: 蟒蛇

  sess = TwoLayerNetwork(name="net")   # sess is object of Network 
  output = sess(tf.ones([1, 8]))
  print([v.name for v in sess.variables])
  ```
  =================================================================
  This example prints variable names, one kernel and one bias per
  `tf.layers.Dense` layer:

  ['net/dense/kernel:0',
   'net/dense/bias:0',
   'net/dense_1/kernel:0',
   'net/dense_1/bias:0']

  These variables can be passed to a `Saver` (`tf.train.Saver`, or
  `tf.contrib.eager.Saver` when executing eagerly) to save or restore the
  `Network` 
  =================================================================
  ```
  tfe.save_network_checkpoint(sess,'./my_model.ckpt') # saving the model
  tfe.restore_network_checkpoint(sess,'./my_model.ckpt') # restoring 

【讨论】:

  • 谢谢@jay Shah。我希望模型以 Eager 模式恢复。
  • @Allen 我已根据您的要求编辑了我的答案...检查一下
  • 感谢您的更新。这几乎对我有用。我已经能够保存经过训练的模型,但是当我尝试恢复它时,恢复的模型中没有任何变量。请查看我的问题中的更新。
  • 要求是: 1 - 需要急切模式而不是会话代码。 2 - 它必须工作。
【解决方案5】:

使用tfe.Saver().save() 保存变量:

for epoch in range(epochs):
    train_and_optimize()
    all_variables = model.variables + optimizer.variables()

    # save the varibles 
    tfe.Saver(all_variables).save(checkpoint_prefix)

然后使用 tfe.Saver().restore() 重新加载保存的变量:

tfe.Saver((model.variables + optimizer.variables())).restore(checkpoint_prefix)

然后使用保存的变量加载模型,无需像@Stefan Falk 的回答那样创建新变量。

【讨论】:

    猜你喜欢
    • 2017-11-27
    • 2018-02-16
    • 2020-08-11
    • 1970-01-01
    • 2016-06-14
    • 2017-08-21
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多