【发布时间】:2017-10-08 21:32:07
【问题描述】:
我有一个简单的目标,即在 tensorflow 中训练模型并在以后恢复它,以便继续训练或使用某些功能/操作。
这是模型的简单示例
import tensorflow as tf
import numpy as np
BATCH_SIZE = 3
VECTOR_SIZE = 1
LEARNING_RATE = 0.1
x = tf.placeholder(tf.float32, [BATCH_SIZE, VECTOR_SIZE],
name='input_placeholder')
y = tf.placeholder(tf.float32, [BATCH_SIZE, VECTOR_SIZE],
name='labels_placeholder')
W = tf.get_variable('W', [VECTOR_SIZE, BATCH_SIZE])
b = tf.get_variable('b', [VECTOR_SIZE], initializer=tf.constant_initializer(0.0))
y_hat = tf.matmul(W, x) + b
predict = tf.matmul(W, x) + b
total_loss = tf.reduce_mean(y-y_hat)
train_step = tf.train.AdagradOptimizer(LEARNING_RATE).minimize(total_loss)
X = np.ones([BATCH_SIZE, VECTOR_SIZE])
Y = np.ones([BATCH_SIZE, VECTOR_SIZE])
all_saver = tf.train.Saver()
sess= tf.Session()
sess.run(tf.global_variables_initializer())
sess.run([train_step], feed_dict = {x: X, y:Y}))
save_path = r'C:\some_path\save\\'
all_saver.save(sess,save_path)
现在我们在这里恢复它:
meta_path = r'C:\some_path\save\.meta'
new_all_saver = tf.train.import_meta_graph(meta_path)
graph = tf.get_default_graph()
all_ops = graph.get_operations()
for el in all_ops:
print(el)
在恢复的操作中,甚至无法从原始代码中找到predict 或train_step。我需要在保存之前命名此操作吗?我怎样才能找回predict 并运行类似的东西
sess=tf.Session()
sess.run([predict], feed_dict = {x:X})
附:我阅读了很多关于在 tensorflow 中保存和恢复的教程,但仍然不太了解它是如何工作的。
【问题讨论】:
标签: tensorflow save restore