【发布时间】:2017-02-28 19:22:05
【问题描述】:
我正在尝试使用 Tensorflow 创建这个超级简单的示例,但我显然不完全了解 Tensorflow 的 API。
我有以下代码。它最初不是我的——我是从一些演示中找到的,但我不记得我在哪里找到它,否则我会感谢作者。道歉。
保存训练好的线模型
import tensorflow as tf
import numpy as np
# Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3
# Try to find values for W and b that compute y_data = W * x_data + b
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='W')
b = tf.Variable(tf.zeros([1]), name='b')
y = W * x_data + b
# Minimize the mean squared errors.
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)
# Before starting, initialize the variables. We will 'run' this first.
init = tf.global_variables_initializer()
# Create a session saver
saver = tf.train.Saver()
# Launch the graph.
sess = tf.Session()
sess.run(init)
# Fit the line.
for step in range(201):
sess.run(train)
if step % 20 == 0:
print(step, sess.run(W), sess.run(b))
saver.save(sess, 'linemodel')
好的,没关系。我只想加载模型,然后查询我的模型以获得预测值。这是我尝试的代码:
加载和查询训练好的线模型
# This is going to load the line model
import tensorflow as tf
sess = tf.Session()
new_saver = tf.train.import_meta_graph('linemodel.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./')) # latest checkpoint
all_vars = tf.global_variables()
for v in all_vars:
v_ = sess.run(v)
print("This is {} with value: {}".format(v.name, v_))
# this works
# None of the below works
# Tried this as well
#fetches = {
# "input": tf.constant(10, name='input')
#}
#feed_dict = {"input": tf.constant(10, name='input')}
#vals = sess.run(fetches, feed_dict = feed_dict)
# Tried this and it didn't work
# query_value = tf.constant(10, name='query')
# print(sess.run(query_value))
这是一个非常基本的问题,但是我怎样才能只传递一个值并像使用函数一样使用我的行。我是否需要更改构建线模型的方式?我的猜测是计算图没有设置,输出是我们可以获得的实际变量。它是否正确?如果是这样,我应该如何修改这个程序?
【问题讨论】:
标签: python tensorflow training-data