【问题标题】:Get label prediction from Cifar-10 model从 Cifar-10 模型获取标签预测
【发布时间】:2016-06-19 07:36:15
【问题描述】:
我目前正在研究 tensorflow 的 Cifar-10 教程。我想更改评估,以便我可以看到每张图像我的模型的预测是什么,以及它是真还是假。我在第一部分苦苦挣扎:如果我打印预测 (sess.run([top_k_op])) 我会得到真/假值,我假设预测是否正确。但是,如果我尝试打印实际预测(到目前为止,我尝试打印 logits,并打印 top_k_op 张量),我会得到一些数字或值,但没有任何看起来像标签的东西。我必须对我的代码进行哪些更改才能真正看到我的模型预测的标签?
【问题讨论】:
标签:
machine-learning
tensorflow
【解决方案1】:
您想首先评估logits。这是网络之外的类的概率分布。具有较高值的张量索引将为您的标签提供最有可能的类别。
您可以使用tf.argmax 获取索引,然后使用标签中的索引将其打印出来
print labels[index]
【解决方案2】:
您可以通过查看here找出答案
在 svhn.py 中,在第 116 行打印预测标签:print (step, int(test_labels[0]))
我通过以下方式清楚地做到了:
classification = sess.run(top_k_predict_op)
print (step, int(test_labels[0]))
print "network predicted:", classification[0], "for real label:", test_labels
确保您预测的是 24*24 图像,以防您使用原始版本的 TensorFlow CIFAR-10 模型训练模型。