【问题标题】:Basic Tensorflow Example - Prediction of a Line基本 TensorFlow 示例 - 直线预测
【发布时间】:2017-02-28 19:22:05
【问题描述】:

我正在尝试使用 Tensorflow 创建这个超级简单的示例,但我显然不完全了解 Tensorflow 的 API。

我有以下代码。它最初不是我的——我是从一些演示中找到的,但我不记得我在哪里找到它,否则我会感谢作者。道歉。

保存训练好的线模型

import tensorflow as tf
import numpy as np

# Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3

# Try to find values for W and b that compute y_data = W * x_data + b
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='W')
b = tf.Variable(tf.zeros([1]), name='b')
y = W * x_data + b

# Minimize the mean squared errors.
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

# Before starting, initialize the variables.  We will 'run' this first.
init = tf.global_variables_initializer()

# Create a session saver
saver = tf.train.Saver()

# Launch the graph.
sess = tf.Session() 

sess.run(init)

# Fit the line.
for step in range(201):
    sess.run(train)
    if step % 20 == 0:
        print(step, sess.run(W), sess.run(b))
        saver.save(sess, 'linemodel')

好的,没关系。我只想加载模型,然后查询我的模型以获得预测值。这是我尝试的代码:

加载和查询训练好的线模型

# This is going to load the line model
import tensorflow as tf

sess = tf.Session()
new_saver = tf.train.import_meta_graph('linemodel.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./')) # latest checkpoint
all_vars = tf.global_variables()
for v in all_vars:
    v_ = sess.run(v)
    print("This is {} with value: {}".format(v.name, v_))
    # this works


# None of the below works
# Tried this as well
#fetches = {
#   "input": tf.constant(10, name='input')
#}

#feed_dict = {"input": tf.constant(10, name='input')}
#vals = sess.run(fetches, feed_dict = feed_dict)
# Tried this and it didn't work
# query_value = tf.constant(10, name='query')

# print(sess.run(query_value))

这是一个非常基本的问题,但是我怎样才能只传递一个值并像使用函数一样使用我的行。我是否需要更改构建线模型的方式?我的猜测是计算图没有设置,输出是我们可以获得的实际变量。它是否正确?如果是这样,我应该如何修改这个程序?

【问题讨论】:

    标签: python tensorflow training-data


    【解决方案1】:

    您必须再次创建张量流图并将保存的权重加载到其中。我在您的代码中添加了几行代码,它提供了所需的输出。请检查一下。

    import tensorflow as tf
    import numpy as np
    
    sess = tf.Session() 
    new_saver = tf.train.import_meta_graph('linemodel.meta')
    new_saver.restore(sess, tf.train.latest_checkpoint('./')) # latest checkpoint
    all_vars = tf.global_variables()
    
    # load saved weights into new variables
    W = all_vars[0]
    b = all_vars[1]
    
    # build TF graph
    x = tf.placeholder(tf.float32)
    y = tf.add(tf.multiply(W,x),b)
    
    # Session
    init = tf.global_variables_initializer()
    print(sess.run(all_vars))
    sess.run(init)    
    for i in range(2):
        x_ip = np.random.rand(10).astype(np.float32) # batch_size : 10
        vals = sess.run(y,feed_dict={x:x_ip})
        print vals
    

    输出:

    [array([ 0.1000001], dtype=float32), array([ 0.29999995], dtype=float32)]
    
    [-0.21707924 -0.18646611 -0.00732027 -0.14248954 -0.54388255 -0.33952206  -0.34291503 -0.54771954 -0.60995424 -0.91694558]
    [-0.45050886 -0.01207681 -0.38950539 -0.25888413 -0.0103816  -0.10003483 -0.04783082 -0.83299863 -0.53189355 -0.56571382]
    

    我希望这会有所帮助。

    【讨论】:

    • 所以你基本上必须重建图形或模型在做什么?结构并没有真正保存?好像不说x = tf.placeholder(tf.float32) y = tf.add(tf.multiply(W,x),b)就无法到达y?谢谢
    • 是的,它只保存变量。我们必须再次创建图表并重新加载保存的变量。我发现有 tf.import_graph_def() 函数来加载模型定义,你可以试试。我可能会稍后尝试。
    • 酷!一如既往地感谢您的帮助。
    猜你喜欢
    • 1970-01-01
    • 2017-10-01
    • 1970-01-01
    • 2020-09-03
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2018-02-28
    • 1970-01-01
    相关资源
    最近更新 更多