【问题标题】:tensorflow accuracy python张量流精度python
【发布时间】:2018-02-25 06:04:44
【问题描述】:

我遇到了 CNN 的问题。拥有下面的代码,我试图识别 dicLabelsNumbers 字典中定义的数字。我有 90 张图片,每个数字 0-9 9 张,重复 100 次。但是我的代码中一定有问题,因为无论我更改 CNN 中的任何参数,准确度都没有差异,准确度始终是精确数字 0.101123594。它永远不会改变。请告诉我:

  1. 如果我从光盘读取图片的过程是正确的
  2. 获得始终相同精度的原因是什么
  3. 如果我知道如果我在 90 张图片上训练 CNN,然后在相同图片上评估 CNN,那么我应该获得 100% 的准确度。对吗?

我的图片是严格尺寸为 16 x 18 的 bmp。

    dicLabelsNumbers = {
'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9
      }

    def cnn_model_mk(features, labels, mode):

      input_layer = tf.reshape(features, [-1, 16, 18, 1])

      conv1 = tf.layers.conv2d(inputs=input_layer, filters=16, kernel_size=[3, 3], padding="same", activation=tf.nn.relu)
      pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=1)
      conv2 = tf.layers.conv2d(inputs=pool1, filters=32, kernel_size=[3, 3], padding="same", activation=tf.nn.relu)
      pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=1)
      pool2_flat = tf.reshape(pool2, [-1, 14*16*32])
      dense = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.softmax)
      dropout = tf.layers.dropout(inputs=dense, rate=0.2, training=mode == tf.estimator.ModeKeys.TRAIN)
      logits = tf.layers.dense(inputs=dropout, units=len(dicLabelsNumbers))

      predictions = {
          "classes": tf.argmax(input=logits, axis=1),
          "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
      }
      if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

      loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

      if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
        train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

      eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions["classes"])}
      return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)


    def _parse_function(filename, label):
      image_string = tf.read_file(filename)
      image_decoded = tf.image.decode_bmp(image_string)
      image_resized = tf.image.resize_images(image_decoded, [16, 18])
      image_resized = tf.image.rgb_to_grayscale(image_resized)
      image_resized = tf.reshape(image_resized, [16, 18, 1])
      return image_resized, label

    def my_input_fn():
      filespath = "./Signs/"
      root = xml.etree.ElementTree.parse(filespath + 'char.xml').getroot()

      filenames = []
      labels = []

      i = 0
      for child in root:
        filename = filespath + child.get("file")
        label = dicLabelsNumbers[child.get("tag")]

        filenames.append(filename)
        labels.append(label)
        i += 1

        if i > 90:
          break

      filenames = tf.constant(filenames)
      labels = tf.constant(labels)

      dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
      dataset = dataset.map(_parse_function)
      dataset = dataset.repeat(100)
      dataset_batched = dataset.batch(1)
      iterator = dataset_batched.make_one_shot_iterator()

      return iterator.get_next()

    def main(unused_argv):

      mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_mk, model_dir=dir)
      mnist_classifier.train(input_fn=my_input_fn, steps=10000)
      eval_results = mnist_classifier.evaluate(input_fn=my_input_fn)
      print(eval_results)



    if __name__ == "__main__":
      tf.app.run()

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    同样的准确性是“机会”:选择正确标签的机会是十分之一。您的网络尚未了解这些图像。

    您是对的,如果您在相同的图像上进行训练和测试,则可以对数据集进行完全建模并获得 100% 的准确度,但仍需要训练。使用这么小的数据集,您可能需要很多 epoch。

    如果您还没有,您应该尝试对输入图像进行白化(有一个用于白化的 TensorFlow 函数)并标准化您的输入数据,使像素值在 0-1 范围内(而不是 0-255)。

    我不熟悉 TensorFlow 数据集类,但您似乎将批量大小设置为 1。我认为这可能会导致您的权重振荡而不是收敛。

    你能输出你的损失吗?如果输入有问题,损失会显示出来。损失应该在训练步骤中以某种方式减少。

    【讨论】:

    • 我设法解决了我的问题。在这种情况下,不需要 dropout 层。删除它后,CNN 开始正常工作。
    猜你喜欢
    • 2018-11-25
    • 1970-01-01
    • 2018-08-26
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2016-08-20
    • 2023-03-28
    • 1970-01-01
    相关资源
    最近更新 更多