【发布时间】:2018-02-25 15:47:21
【问题描述】:
我按照本教程 https://www.tensorflow.org/tutorials/layers 训练了一个模型,用于识别 MNIST 集中的手写数字。
以下代码按预期工作,并为集合中的每个图像打印概率和类别
mnist = tf.contrib.learn.datasets.load_dataset("mnist")
train_data = mnist.train.images # Returns np.array
tf.reset_default_graph()
with tf.Session() as sess:
mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn, model_dir="model/")
pred = mnist_classifier.predict(input_fn=tf.estimator.inputs.numpy_input_fn(
x={"x": train_data},
shuffle=False))
for p in pred:
print(p)
但是,当我尝试仅使用
预测一张图像时mnist_classifier.predict(input_fn=tf.estimator.inputs.numpy_input_fn(
x={"x": train_data[0]},
shuffle=False))
我的程序失败,TensorFlow 报告
InvalidArgumentError: Input to reshape is a tensor with 128 values,
but the requested shape requires a multiple of 784
这让我感到困惑,因为当我打印集合中第一张图像的长度时,它报告为 784
print("length of input: {}".format(len(train_data[0]))
如何仅获得一张图像的预测结果?
【问题讨论】:
标签: python tensorflow mnist