【发布时间】:2018-09-26 18:02:45
【问题描述】:
我在我自己的数据集上使用@kratzert 编写的finetune AlexNet 架构,它工作正常(我从这里得到代码:https://github.com/kratzert/finetune_alexnet_with_tensorflow),我想弄清楚如何从他的代码中构建混淆矩阵。我曾尝试使用tf.confusion_matrix(labels, predictions, num_classes) 来构建混淆矩阵,但我不能。我很困惑标签和预测的值应该是什么,我的意思是,我知道应该是什么,但是每次输入这些值时都会出错。任何人都可以帮助我或查看代码(以上链接)并指导我吗?
我在计算准确度之后在finetune.py中添加了这两行,以使标签和预测作为类的数量。
with tf.name_scope("accuracy"):
correct_pred = tf.equal(tf.argmax(score, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
**true_class = tf.argmax(y, 1)
predicted_class = tf.argmax(score, 1)**
在保存模型检查点之前,我在会话的最底部添加了tf.confusion_matrix()
for _ in range(val_batches_per_epoch):
img_batch, label_batch = sess.run(next_batch)
acc, cost = sess.run([accuracy, loss], feed_dict={x: img_batch,
y: label_batch,
keep_prob: 1.})
test_acc += acc
test_count += 1
test_acc /= test_count
print("{} Validation Accuracy = {:.4f} -- Validation Loss = {:.4f}".format(datetime.now(),test_acc, cost))
print("{} Saving checkpoint of model...".format(datetime.now()))
**print(sess.run(tf.confusion_matrix(true_class, predicted_class, num_classes)))**
# save checkpoint of the model
checkpoint_name = os.path.join(checkpoint_path,
'model_epoch'+str(epoch+1)+'.ckpt')
save_path = saver.save(sess, checkpoint_name)
print("{} Model checkpoint saved at {}".format(datetime.now(),
checkpoint_name))
我也尝试过其他地方,但每次都会出错:
Caused by op 'Placeholder_1', defined at:
File "/home/armin/Desktop/Alexnet_DataPipeline/finetune.py", line 85, in <module>
y = tf.placeholder(tf.float32, [batch_size, num_classes])
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/array_ops.py", line 1777, in placeholder
return gen_array_ops.placeholder(dtype=dtype, shape=shape, name=name)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 4521, in placeholder
"Placeholder", dtype=dtype, shape=shape, name=name)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 3290, in create_op
op_def=op_def)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 1654, in __init__
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access
InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder_1' with dtype float and shape [128,3]
任何帮助将不胜感激,谢谢。
【问题讨论】:
-
你能贴出你的代码和错误(重要的部分,不是完整的代码)吗?
-
我添加了部分代码和我添加的行来计算混淆矩阵和我的错误
标签: tensorflow scikit-learn tensorboard tensorflow-datasets