【问题标题】:using saved model to make prections in tensorflow使用保存的模型在张量流中进行预测
【发布时间】:2018-11-22 03:59:13
【问题描述】:

我有这段代码可以在 tensorflow 中训练恢复模型。但我怎样才能做出预测。

def train_neural_network(x):
    prediction=neural_network_model(x)
    cost=tf.nn.softmax_cross_entropy_with_logits_v2(logits = prediction, labels = y)
    optimizer=tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost)

    saver = tf.train.Saver()

    with tf.Session() as sess:
        #sess.run(tf.initialize_all_variables())
        sess.run(tf.global_variables_initializer())
        for epoch in range(hm_epochs):
            epoch_loss = 0
            i = 0
            #while i < len(train_x):
            t = len(train_x)
            f = t%batch_size
            while i < (t-f):
                start = i
                end = i+batch_size
                batch_x = np.array(train_x[start:end])
                batch_y = np.array(train_y[start:end])


                _, c = sess.run([optimizer, cost], feed_dict={x: batch_x, y: batch_y})
                epoch_loss += c
                #epoch_loss = epoch_loss + c
                i+=batch_size
                #i = i + batch_size
            print('Epoch =', epoch+1, '/',hm_epochs,'loss:',epoch_loss)

        save_path = saver.save(sess, "sesionestensorflow/model1.ckpt")
        print("Model saved in path: %s" % save_path)


        correct = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct, 'float'))

        print('Accuracy:',accuracy.eval({x:test_x, y:test_y}))

我看到了这个答案,但我无法做出预测。 Using saved model for prediction in tensorflow

【问题讨论】:

    标签: tensorflow neural-network


    【解决方案1】:

    您只需要创建一个 empy 图、定义网络、加载保存的权重并运行推理。

    prediction = tf.argmax(neural_network_model(x), 1)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # load the trained weights into the model
        saver.restore(sess, "sesionestensorflow/model1.ckpt")
        # just use the model
        out = sess.run(prediction, feed_dict={x: <your input> })
    

    【讨论】:

    • 好的,结果是这样的,但我希望一个零数组有一个数字 1 。进料的dict = [[0.6737 0.6737 0.6737 41.606056 16.11666667]] OUT [[62.6507 -2025.7719 136.06538 -2349.9055 -4071.8032 -1727.9988 78.84228 -305.2752 -1380.6096 299.88174 1483.1726 1353.7816 2282.4849 2496.1821 1986.554] 跨度>
    • 因为预测是模型的输出。您正在寻找概率较高的班级。我更新答案
    猜你喜欢
    • 2018-09-28
    • 2020-04-14
    • 1970-01-01
    • 2017-11-22
    • 1970-01-01
    • 2019-12-20
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多