【发布时间】:2017-12-17 05:39:21
【问题描述】:
我已经在 TensorFlow Eager 模式下训练了一个 CNN 模型。现在我正在尝试从检查点文件中恢复经过训练的模型,但没有成功。
我发现的所有示例(如下所示)都在谈论将检查点恢复到会话。但我需要将模型恢复为渴望模式,即不创建会话。
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "/tmp/model.ckpt")
基本上我需要的是这样的:
tfe.enable_eager_execution()
model = tfe.restore('model.ckpt')
model.predict(...)
然后我可以使用模型进行预测。
有人可以帮忙吗?
更新
示例代码见:mnist eager mode demo
我尝试按照@Jay Shah 的回答中的步骤进行操作,几乎可以正常工作,但恢复后的模型中没有任何变量。
tfe.save_network_checkpoint(model,'./test/my_model.ckpt')
Out[58]:
'./test/my_model.ckpt-1720'
model2 = MNISTModel()
tfe.restore_network_checkpoint(model2,'./test/my_model.ckpt-1720')
model2.variables
Out[72]:
[]
原始模型中有很多变量。:
model.variables
[<tf.Variable 'mnist_model_1/conv2d/kernel:0' shape=(5, 5, 1, 32) dtype=float32, numpy=
array([[[[ -8.25184360e-02, 6.77833706e-03, 6.97569922e-02,...
【问题讨论】:
-
为什么检查点名称不同???保存检查点路径与您正在恢复的路径不同...输出对我来说似乎很奇怪
-
如果我使用相同的检查点名称,它将不起作用。 'my_model.ckpt-1720' 是 save_network_checkpoint 函数返回的名称。根据文档,这应该是用于恢复模型的名称。
-
哦..是的,返回的值必须去...只是确保您以正确的方式进行操作
-
嘿@Allen 你可以试试这个来保存模型
tf.contrib.eager.Saver([variable_list]).save(chkpt_file),它会返回一个字符串,所以在恢复时使用该字符串如下:tf.contrib.eager.Saver.restore(returned_string)这个东西模仿tf.train.Saver,但没有会话为此需要...但是必须启用急切模式...在制作 saver 对象时,您必须提供要存储的变量列表...。您可以从我的答案中获取变量列表.. . 变量必须是tfe.Variable
标签: python tensorflow deep-learning