【问题标题】:Tensorflow: I get something wrong in accuracyTensorflow:我的准确性有问题
【发布时间】:2017-11-18 19:09:09
【问题描述】:

我只是运行一个简单的代码,并希望在训练后获得准确性。我加载了我保存的模型,但是当我想获得准确性时,我出错了。为什么?

# coding=utf-8
from  color_1 import read_and_decode, get_batch, get_test_batch
import AlexNet
import cv2
import os
import time
import numpy as np
import tensorflow as tf
import AlexNet_train
import math

batch_size=128
num_examples = 1000
crop_size=56

def evaluate(test_x, test_y):
    image_holder = tf.placeholder(tf.float32, [batch_size, 56, 56, 3], name='x-input')
    label_holder = tf.placeholder(tf.int32, [batch_size], name='y-input')

    y = AlexNet.inference(image_holder,evaluate,None)

    correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(label_holder,1))
    accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    saver = tf.train.Saver()
    with tf.Session() as sess:
        init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        sess.run(init_op)
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        ckpt=tf.train.get_checkpoint_state(AlexNet_train.MODEL_SAVE_PATH)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
            saver.restore(sess, os.path.join(AlexNet_train.MODEL_SAVE_PATH, ckpt_name))
            print('Loading success, global_step is %s' % global_step)
            step=0

            image_batch, label_batch = sess.run([test_x, test_y])
            accuracy_score=sess.run(accuracy,feed_dict={image_holder: image_batch,
                                                              label_holder: label_batch})
            print("After %s training step(s),validation "
                  "precision=%g" % (global_step, accuracy_score))
        coord.request_stop()  
        coord.join(threads)

def main(argv=None):
    test_image, test_label = read_and_decode('val.tfrecords')

    test_images, test_labels = get_test_batch(test_image, test_label, batch_size, crop_size)

    evaluate(test_images, test_labels)


if __name__=='__main__':
    tf.app.run()

这是错误,它说我的代码中的这一行是错误的:“correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(label_holder,1))”

Traceback (most recent call last):
  File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 80, in <module>
    tf.app.run()
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 44, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 76, in main
    evaluate(test_images, test_labels)
  File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 45, in evaluate
    label_holder: label_batch})
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 767, in run
    run_metadata_ptr)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 965, in _run
    feed_dict_string, options, run_metadata)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1015, in _do_run
    target_list, options, run_metadata)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1035, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected dimension in the range [-1, 1), but got 1
     [[Node: ArgMax_1 = ArgMax[T=DT_INT32, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_y-input_0, ArgMax_1/dimension)]]

Caused by op u'ArgMax_1', defined at:
  File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 80, in <module>
    tf.app.run()
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 44, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 76, in main
    evaluate(test_images, test_labels)
  File "/home/vrview/tensorflow/example/char/tfrecords/AlexNet/Alex_save/AlexNet_test.py", line 22, in evaluate
    correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(label_holder,1))
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/ops/math_ops.py", line 263, in argmax
    return gen_math_ops.arg_max(input, axis, name)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/ops/gen_math_ops.py", line 168, in arg_max
    name=name)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 763, in apply_op
    op_def=op_def)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2395, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1264, in __init__
    self._traceback = _extract_stack()

InvalidArgumentError (see above for traceback): Expected dimension in the range [-1, 1), but got 1
     [[Node: ArgMax_1 = ArgMax[T=DT_INT32, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_y-input_0, ArgMax_1/dimension)]]

如何解决?

【问题讨论】:

  • “accuracy_score=sess.run(accuracy,feed_dict={image_holder: image_batch,label_holder: label_batch}) 行似乎有误。但是不知道怎么解决
  • 错误信息告诉你的事情是什么?它直截了当地说 arg_max 参数需要在 [-1, 1) 范围内,但是您传递的是 1,这是无效的。修复它不起作用还是您只是没有阅读错误消息?
  • 我当然知道错误,只是不知道如何解决。我不擅长这个。那么你能告诉我如何解决吗?我应该添加一些东西吗?
  • 好吧,错误消息告诉您传递不是数字 1 的东西,而您传递的是数字 1。一种可能的解决方案是传递不是数字 1 的东西(并且更好的是,实际上是在 [-1,1) 范围内的东西,例如 0.5 或 0.9 或其他。我不知道那个论点实际上做了什么,所以我不知道最好的数字,但你至少可以看看它是否能解决你的错误。
  • e...我看到了'tf.equal'和'tf.argmax'的用法,但似乎它的数量没有限制。所以我感到困惑。我输入的是图像,大小是[56,56,3],'y'是全连接层的输出。

标签: python tensorflow


【解决方案1】:

在这里扮演与问题相关的this answer

tf.argmax's definition 状态:

轴:张量。必须是以下类型之一:int32、int64。 int32, 0 。描述输入的哪个轴 要减少的张量。

看来,在张量的最后一个轴上运行argmax 的唯一方法是给它axis=-1,因为函数定义中的“严格小于”符号。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2016-03-18
    • 1970-01-01
    • 1970-01-01
    • 2018-06-15
    • 2018-04-14
    • 2017-02-26
    • 1970-01-01
    相关资源
    最近更新 更多