【问题标题】:TensorFlow:Evaluate test set multiple times but get different accuracyTensorFlow:多次评估测试集但获得不同的准确性
【发布时间】:2017-12-20 23:49:36
【问题描述】:

我已经使用CNN训练了MNIST的模型,但是当我在训练后用测试数据检查模型的准确性时,我发现我的准确性会有所提高。这是代码。

BATCH_SIZE = 50
LR = 0.001              # learning rate
mnist = input_data.read_data_sets('./mnist', one_hot=True)  # they has been normalized to range (0,1)
test_x = mnist.test.images[:2000]
test_y = mnist.test.labels[:2000]

def new_cnn(imageinput, inputshape):
    weights = tf.Variable(tf.truncated_normal(inputshape, stddev = 0.1),name = 'weights')
    biases = tf.Variable(tf.constant(0.05, shape = [inputshape[3]]),name = 'biases')
    layer = tf.nn.conv2d(imageinput, weights, strides = [1, 1, 1, 1], padding = 'SAME')
    layer = tf.nn.relu(layer)
    return weights, layer

tf_x = tf.placeholder(tf.float32, [None, 28 * 28])
image = tf.reshape(tf_x, [-1, 28, 28, 1])              # (batch, height, width, channel)
tf_y = tf.placeholder(tf.int32, [None, 10])            # input y

# CNN
weights1, layer1 = new_cnn(image, [5, 5, 1, 32])
pool1 = tf.layers.max_pooling2d(
    layer1,
    pool_size=2,
    strides=2,
)           # -> (14, 14, 32)
weight2, layer2 = new_cnn(pool1, [5, 5, 32, 64])    # -> (14, 14, 64)
pool2 = tf.layers.max_pooling2d(layer2, 2, 2)    # -> (7, 7, 64)
flat = tf.reshape(pool2, [-1, 7 * 7 * 64])          # -> (7*7*64, )
hide = tf.layers.dense(flat, 1024, name = 'hide')              # hidden layer
output = tf.layers.dense(hide, 10, name = 'output')
loss = tf.losses.softmax_cross_entropy(onehot_labels=tf_y, logits=output)           # compute cost
accuracy = tf.metrics.accuracy( labels=tf.argmax(tf_y, axis=1), predictions=tf.argmax(output, axis=1),)[1]
train_op = tf.train.AdamOptimizer(LR).minimize(loss)



sess = tf.Session()
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) # the local var is for accuracy
sess.run(init_op)     # initialize var in graph
saver = tf.train.Saver()
for step in range(101):
    b_x, b_y = mnist.train.next_batch(BATCH_SIZE)
    _, loss_ = sess.run([train_op, loss], {tf_x: b_x, tf_y: b_y})
    if step % 50 == 0:
        print(loss_)
        accuracy_, loss2 = sess.run([accuracy, loss], {tf_x: test_x, tf_y: test_y })
        print('Step:', step, '| test accuracy: %f' % accuracy_)

为了简化问题,我只使用了 100 次训练迭代。并且测试集的最终准确率约为0.655000

但是当我运行以下代码时:

for i in range(5):
  accuracy2 = sess.run(accuracy, {tf_x: test_x, tf_y: test_y })
  print(sess.run(weight2[1,:,0,0])) # To show that the model parameters won't update 
  print(accuracy2)

输出是

[-0.06928255 -0.13498515  0.01266837  0.05656774  0.09438231]
0.725875
[-0.06928255 -0.13498515  0.01266837  0.05656774  0.09438231]
0.7684
[-0.06928255 -0.13498515  0.01266837  0.05656774  0.09438231]
0.79675
[-0.06928255 -0.13498515  0.01266837  0.05656774  0.09438231]
0.817
[-0.06928255 -0.13498515  0.01266837  0.05656774  0.09438231]
0.832187

这让我很困惑,谁能告诉我怎么了? 感谢您的耐心等待!

【问题讨论】:

  • 请包含完整代码。例如 wgat 你使用 keep_prob 吗?
  • @lejlot 对此感到抱歉,我删除了多余的部分。
  • 您确定权重不会改变吗?打印它们是不够的,因为单次通过可能会有非常小的变化,可能超出打印的小数点。您的代码执行不排除这种情况。
  • @Eric Platon 但是在评估测试集的过程中可能不会执行训练操作。

标签: machine-learning tensorflow deep-learning conv-neural-network tensorflow-serving


【解决方案1】:

tf.metrics.accuracy 并不像你想象的那么简单。看看它的文档:

accuracy 函数创建两个局部变量,total
count,用于计算频率 predictions 匹配 labels。这个频率最终是 返回为accuracy:一个简单除法的幂等操作 totalcount

在内部,is_correct 操作计算 Tensor 与 元素 1.0 其中predictions 的对应元素和 labels 匹配,否则为 0.0。然后update_op 递增 totalweightsis_correct,它增加了count,减少的总和 weights.

为了估计数据流上的度量,函数 创建一个 update_op 操作来更新这些变量和 返回accuracy

...

返回:

  • accuracy:一个Tensor代表准确率,total的值除 count
  • update_op:递增totalcount 变量的操作 适当且其值与accuracy 匹配。

请注意,它返回一个元组,你取第二项,即update_opupdate_op 的连续调用被视为数据流,这不是您打算做的(因为每次评估在训练期间都会影响未来的评估)。其实这个运行指标是pretty counter-intuitive

您的解决方案是使用简单明了的精度计算。将此行更改为:

accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(tf_y, axis=1), tf.argmax(output, axis=1)), tf.float32))

您将获得稳定的精度计算。

【讨论】:

    猜你喜欢
    • 2023-03-08
    • 2015-06-18
    • 2018-04-12
    • 2017-08-07
    • 2016-08-08
    • 2020-03-23
    • 1970-01-01
    • 1970-01-01
    • 2018-10-08
    相关资源
    最近更新 更多