【问题标题】:tf.argmax() returning unexpected resultstf.argmax() 返回意外结果
【发布时间】:2018-04-19 14:31:44
【问题描述】:

我最近在做一个基于tensorflow CNN,MNIST数据集的项目,带有服务器接口。

在预测部分,我使用 tf.argmax() 得到最大的 logit,这将是预测值。但是,它返回的值似乎不是正确的答案。

predict函数大概是这样的:

    self.img = tf.reshape(tf.image.convert_image_dtype(img, tf.float32), shape=[1, 28, 28, 1])
    self._create_model()

    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state('../checkpoints/')
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
    saver.restore(sess, ckpt.model_checkpoint_path)

    pred = tf.nn.softmax(self.logits)
    prediction = tf.argmax(pred, 1)
    logit = sess.run(pred)
    result = sess.run(prediction)[0]
    print(logit)
    print(result)

    return result

结果是:

127.0.0.1 - - [19/Apr/2018 21:35:47] "POST /index.html HTTP/1.1" 200 -
[[ 0.  0.  0.  0.  0.  1.  0.  0.  0.  0.]]
1

如您所见,logits 显示最大数的索引是 5,但是 tf.argmax() 给了我 1 代替。

顺便说一下,my model 是基本的 MNIST CNN 模型,您可以在链接中看到。

那么这个 tf.argmax() 函数是怎么回事,还是我的代码有问题?

【问题讨论】:

  • 为什么不这样尝试呢? pred = tf.nn.softmax(self.logits) logit = sess.run(pred) result = tf.argmax(logit)

标签: python tensorflow mnist argmax


【解决方案1】:

由于您的logit(pred) 和result(prediction[0]) 来自两个不同的sess.run,我想知道运行之间是否存在一些差异。例如,您在图中有一个迭代器将输入发送到模型。通过不同的运行,迭代器发送不同的数据,导致不同的预测。如果您将predprediction 放在同一个sess.run 中会很有趣,如下所示:

logit, result = sess.run((pred, prediction))
print(logit)
print(result[0])

【讨论】:

    猜你喜欢
    • 2017-11-14
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2015-05-10
    • 2016-02-24
    相关资源
    最近更新 更多