标签编码似乎正确。如果您有多个正确的标签,[1 0 1 0 ... 1] 看起来完全没问题。 Denny 的post 中使用的损失函数是tf.nn.softmax_cross_entropy_with_logits,这是针对多类问题的损失函数。
计算 logits 和标签之间的 softmax 交叉熵。
测量离散分类任务中的概率误差
哪些类是互斥的(每个条目都属于一个类)。
在多标签问题中,你应该使用tf.nn.sigmoid_cross_entropy_with_logits:
在给定 logits 的情况下计算 sigmoid 交叉熵。
衡量离散分类任务中的概率误差,其中每个类都是独立的,而不是互斥的。例如,可以执行多标签分类,其中一张图片可以同时包含大象和狗。
损失函数的输入是 logits (WX) 和目标(标签)。
修正准确度度量
为了正确衡量多标签问题的准确性,需要更改以下代码。
# Calculate Accuracy
with tf.name_scope("accuracy"):
correct_predictions = tf.equal(self.predictions, tf.argmax(self.input_y, 1))
self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name="accuracy")
当你可以有多个正确的标签时,上面的correct_predictions 的逻辑是不正确的。例如,假设num_classes=4,标签 0 和 2 是正确的。因此,您的 input_y=[1, 0, 1, 0]. correct_predictions 需要打破索引 0 和索引 2 之间的关系。我不确定tf.argmax 如何打破关系,但如果它通过选择较小的索引来打破关系,则标签 2 的预测总是被认为是错误的,这肯定会损害您的准确性。
实际上在多标签问题中,precision and recall 是比准确度更好的指标。您也可以考虑使用precision@k (tf.nn.in_top_k) 来报告分类器性能。