【问题标题】:How can i get the predicted class labels in mnist tensorflow python?如何在 mnist tensorflow python 中获得预测的类标签?
【发布时间】:2020-07-01 16:05:39
【问题描述】:
在训练 mnist 张量流模型之后,我测试模型。我不知道我们如何提取预测标签结果列表。它只是给了我准确性和损失值。谁能帮我解决这个问题?
# Evaluate the model and print results
eval_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
x={"x": eval_data},
y=eval_labels,
num_epochs=1,
shuffle=False)
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
【问题讨论】:
标签:
python
tensorflow
testing
predict
mnist
【解决方案1】:
在 MNIST 数据集中,标签以上面的数字编码:
0 = T-shirt/top
1 = Trouser
2 = Pullover
3 = Dress
4 = Coat
5 = Sandal
6 = Shirt
7 = Sneaker
8 = Bag
9 = Ankle boot
因此,当您调用 model.predict(test_image) 时,它将返回一个长度等于 10 的数组,其中每个索引表示模型的置信度,即图像对应于每件不同的服装,完全按照上面的数字在数组顺序中,所以如果数组在索引 3 中的数字最高,则意味着模型预测这件衣服是连衣裙,要获得数组中的最高索引,您可以调用np.argmax(model.predict(test_image)[0])
为了方便索引到真实值的映射,即如果预测的服装是连衣裙或靴子,可以将服装名称存储在数组中,如:
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
最后你可以拨打:class_names[np.argmax(model.predict(test_image)[0])]