【问题标题】:How to restore a tensorflow model?如何恢复张量流模型?
【发布时间】:2016-07-18 03:46:29
【问题描述】:
我正在尝试使用.ckpt 文件恢复模型,我通过在tensorflow/models/embedding 中运行word2vec_optimized.py 得到该文件。我不确定如何恢复变量以便我可以加载模型并使用它,因为所有 tf 变量都封装在 tensorflow/models/embedding/word2vec_optimized.py 的类中并初始化。任何帮助,将不胜感激。
此外,如果我“恢复”创建的 .ckpt,我现在是否有一个 Wor2Vec 实例,或者当我使用 .ckpt 恢复模型时实际得到什么?
【问题讨论】:
标签:
python
tensorflow
word2vec
【解决方案1】:
当您在保护程序上调用 save 函数时,您会将用于训练模型的 tf.Session 传递给它。这包含对包含所有变量的图形的引用。不要将 python 变量与 tensorflow 变量混淆。即使您在 python 中不再有指向您创建的 tensorflow 变量的变量,如果它是计算图的一部分,它仍然存在。创建模型后,尝试运行以下代码。
for v in tf.all_variables():
print(v.name)
这将打印出您创建的每个变量的名称。默认情况下,保护程序将保存所有这些。只要变量在您恢复它们时具有相同的名称,它们的创建位置就无关紧要。只需确保在将所有变量添加到模型后进行恢复即可。当您为变量提供初始化程序时,初始化仅在您调用 sess.run(tf.initialize_all_variables()) 时运行。如果您只是恢复值,则不需要调用它。我经常使用下面的代码。
sess = tf.Session()
saver = tf.train.Saver()
if 'restore' in sys.argv:
saver.restore(sess, '/media/chase/98d61322-9ea7-473e-b835-8739c77d1e1e/model.chk')
else:
sess.run(tf.initialize_all_variables())
当我使用在其中创建变量的 thensorflow RNN 类时,此代码可以正常工作。