【发布时间】:2018-11-11 09:48:23
【问题描述】:
我正在尝试将预测图像保存在我用 Tensorflow 编写的 CNN 网络上。在我的代码y_pred_cls 中包含我的预测标签,y_pred_cls 是一个尺寸为 1 x 批量大小的张量。现在,我想将 y_pred_cls 作为一个数组进行迭代,并创建一个包含 pred 类、真实类和一些索引号的文件名,然后找出与预测标签相关的图像并使用 imsave 保存为图像。
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
train_writer.add_graph(sess.graph)
print("{} Start training...".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
print("{} Open Tensorboard at --logdir {}".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), tensorboard_dir))
for epoch in range(FLAGS.num_epochs):
print("{} Epoch number: {}".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), epoch + 1))
step = 1
# Start training
while step < train_batches_per_epoch:
batch_xs, batch_ys = train_preprocessor.next_batch(FLAGS.batch_size)
opt, train_acc = sess.run([optimizer, accuracy], feed_dict={x: batch_xs, y_true: batch_ys})
# Logging
if step % FLAGS.log_step == 0:
s = sess.run(sum, feed_dict={x: batch_xs, y_true: batch_ys})
train_writer.add_summary(s, epoch * train_batches_per_epoch + step)
step += 1
# Epoch completed, start validation
print("{} Start validation".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
val_acc = 0.
val_count = 0
cm_running_total = None
for _ in range(val_batches_per_epoch):
batch_tx, batch_ty = val_preprocessor.next_batch(FLAGS.batch_size)
acc, loss , conf_m= sess.run([accuracy, cost, tf.confusion_matrix(y_true_cls, y_pred_cls, FLAGS.num_classes)],
feed_dict={x: batch_tx, y_true: batch_ty})
if cm_running_total is None:
cm_running_total = conf_m
else:
cm_running_total += conf_m
val_acc += acc
val_count += 1
val_acc /= val_count
s = tf.Summary(value=[
tf.Summary.Value(tag="validation_accuracy", simple_value=val_acc),
tf.Summary.Value(tag="validation_loss", simple_value=loss)
])
val_writer.add_summary(s, epoch + 1)
print("{} -- Training Accuracy = {:.4%} -- Validation Accuracy = {:.4%} -- Validation Loss = {:.4f}".format(
datetime.now().strftime('%Y-%m-%d %H:%M:%S'), train_acc, val_acc, loss))
# Reset the dataset pointers
val_preprocessor.reset_pointer()
train_preprocessor.reset_pointer()
print("{} Saving checkpoint of model...".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
# save checkpoint of the model
checkpoint_path = os.path.join(checkpoint_dir, 'model_epoch.ckpt' + str(epoch+1))
save_path = saver.save(sess, checkpoint_path)
print("{} Model checkpoint saved at {}".format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), checkpoint_path))
batch_tx,batch_ty 分别是我的 RGB 数据和标签。
提前致谢。
【问题讨论】:
-
您应该在外循环之前只创建一次
tf.confusion_matrix(y_true_cls, y_pred_cls, FLAGS.num_classes),否则您将在图中创建同一操作的多个实例。 -
嘿@jdehesa,我应该在哪里创建
tf.confusion_matrix()?因为如果我在循环之后创建混淆矩阵,我只会得到每个批次的混淆矩阵,并且对于有混淆矩阵,我必须提供x: batch_tx, y_true: batch_ty否则我会得到一个错误。 -
在你定义了
y_true_cls和y_pred_cls之后,你应该能够在图构建时间(在任何训练循环之前)执行类似conf_mat = tf.confusion_matrix(y_true_cls, y_pred_cls, FLAGS.num_classes)的操作。然后在循环内做acc, loss , conf_m= sess.run([accuracy, cost, conf_mat], feed_dict={x: batch_tx, y_true: batch_ty})(除非我不明白什么......) -
我按照你说的做了,在开始我的
tf.Session()之前添加conf_mat = tf.confusion_matrix(y_true_cls, y_pred_cls, FLAGS.num_classes),在循环中我做了acc, loss , conf_m= sess.run([accuracy, cost, conf_mat], feed_dict={x: batch_tx, y_true: batch_ty}),但仍然得到一个批次的混淆矩阵,我需要总结混淆矩阵所有批次都得到一个完整的混淆矩阵,结果和我以前做的完全一样。 -
啊,好吧,这就是你的意思,对,是的,你仍然必须这样做,我的意思是,如果你在循环中继续调用
tf.confusion_matrix,你会创建许多相同的副本留在你的图表中的 TensorFlow 操作,所以最好只创建一次并重用它。但是是的,否则您的代码是正确的。
标签: python python-3.x numpy matplotlib tensorflow