不幸且令人困惑的是,tf.contrib.metrics.streaming_auc() 采用与 tf.confusion_matrix() 不同的参数。您应该不将类标签传递给这个类标签,这与 tf.confusion_matrix() 不同。
这里predictions 应该是每个类的概率(你从y_pred_class 得到的值tf.argmax()...)。
标签也应该是 one_hot 格式,而且是布尔张量。 (您可以简单地将其转换为布尔值,0 将变为假,1 将变为真。)
查看您的代码,另一个问题是您必须为您运行的每个批次运行 AUC update_op,以便 AUC 可以累积数据。另一方面,一旦完成一定数量,您需要重置 AUC 的内部变量为零,以便进一步计算,例如用于验证的 AUC,或下一个训练批次'不包括以前的数据。为此,您需要能够在所有变量中找出这些变量,因此我将其放在变量范围内,我称之为“AUC”。综上所述,我更改了代码的几个部分(将在底部附上完整的运行、测试代码以供参考):
定义 auc 的位置(注意,我已从 contrib 中已弃用的版本切换到当前版本 tf.metrics.auc()):
with tf.variable_scope( "AUC" ):
auc, auc_update_op = tf.metrics.auc( predictions=y_pred, labels=y_true, curve = 'ROC' )
auc_variables = [ v for v in tf.local_variables() if v.name.startswith( "AUC" ) ]
auc_reset_op = tf.initialize_variables( auc_variables )
show_progress() 函数(注意msg = ... 行中的一些格式更改,这些并非绝对必要,但仅反映我的口味):
def show_progress(epoch, feed_dict_train, feed_dict_validate, val_loss):
acc, auc_value = session.run([ accuracy, auc_update_op ], feed_dict=feed_dict_train)
session.run( auc_reset_op )
val_acc, val_auc_value = session.run([ accuracy, auc_update_op ], feed_dict=feed_dict_validate)
session.run( auc_reset_op )
msg = 'Training Epoch {} --- Tr Acc: {:>6.1%}, Tr AUC: {:>6.1%}, Val Acc: {:>6.1%}, Val AUC: {:>6.1%}, Val Loss: {:.3f}'
print(msg.format(epoch + 1, acc, auc_value, val_acc, val_auc_value, val_loss))
最后,在train() 函数中,注意插入的auc_update_op:
session.run( [ optimizer, auc_update_op ], feed_dict=feed_dict_tr)
另外,将会话初始化放在最后,就像“最佳实践”一样,这不是绝对必要的:
with tf.Session() as session:
#session.run(tf.initialize_all_variables())
# need to init local variables for internal auc calculations
init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
session.run( init )
train( num_iteration=10000 )
有了这个,输出是:
训练 Epoch 1 --- Tr Acc:51.6%,Tr AUC:51.0%,Val Acc:46.1%,Val AUC:46.1%,Val Loss:4.893
训练时期 2 --- Tr Acc:53.9%,Tr AUC:53.0%,Val Acc:53.5%,Val AUC:55.3%,Val Loss:0.691
训练 Epoch 3 --- Tr Acc:65.6%,Tr AUC:63.9%,Val Acc:65.2%,Val AUC:69.0%,Val Loss:0.647
训练时期 4 --- Tr Acc:71.1%,Tr AUC:71.6%,Val Acc:68.0%,Val AUC:74.8%,Val Loss:0.586
训练 Epoch 6 --- Tr Acc:73.0%,Tr AUC:76.8%,Val Acc:69.5%,Val AUC:75.9%,Val Loss:0.588
训练时期 7 --- Tr Acc:77.3%,Tr AUC:82.4%,Val Acc:73.8%,Val AUC:77.7%,Val Loss:0.563
训练 Epoch 8 --- Tr Acc:81.2%,Tr AUC:87.0%,Val Acc:78.9%,Val AUC:85.2%,Val Loss:0.475
训练 Epoch 9 --- Tr Acc:83.6%,Tr AUC:90.9%,Val Acc:75.0%,Val AUC:83.5%,Val Loss:0.517
训练 Epoch 11 --- Tr Acc:91.8%,Tr AUC:94.3%,Val Acc:73.0%,Val AUC:81.4%,Val Loss:0.646
(见附件train.html或train.ipynb下面)
数据:
Data is on Kaggle, from the Dogs vs. Cats competition (train.zip)
三种格式的可运行测试主代码(请注意,您需要根据您的特定设置调整代码中的数据路径,并根据您的硬件容量调整批量大小):
train.ipynbtrain.htmltrain.py
必需(导入)文件,将其放在与train 相同的文件夹中:
dataset.py