您混淆了二元和多类问题的交叉熵。
多类交叉熵
你用的公式是正确的,直接对应tf.nn.softmax_cross_entropy_with_logits:
-tf.reduce_sum(p * tf.log(q), axis=1)
p 和 q 预计是 N 个类别的概率分布。特别是,N 可以是 2,如下例所示:
p = tf.placeholder(tf.float32, shape=[None, 2])
logit_q = tf.placeholder(tf.float32, shape=[None, 2])
q = tf.nn.softmax(logit_q)
feed_dict = {
p: [[0, 1],
[1, 0],
[1, 0]],
logit_q: [[0.2, 0.8],
[0.7, 0.3],
[0.5, 0.5]]
}
prob1 = -tf.reduce_sum(p * tf.log(q), axis=1)
prob2 = tf.nn.softmax_cross_entropy_with_logits(labels=p, logits=logit_q)
print(prob1.eval(feed_dict)) # [ 0.43748799 0.51301527 0.69314718]
print(prob2.eval(feed_dict)) # [ 0.43748799 0.51301527 0.69314718]
请注意,q 正在计算 tf.nn.softmax,即输出概率分布。所以它仍然是多类交叉熵公式,仅适用于 N = 2。
二元交叉熵
这次正确的公式是
p * -tf.log(q) + (1 - p) * -tf.log(1 - q)
虽然在数学上它是多类情况的部分情况,但p 和q 的含义是不同的。最简单的情况,每个p和q都是一个数字,对应A类的一个概率。
重要提示:不要对常见的p * -tf.log(q) 部分和总和感到困惑。以前的p 是一个单热向量,现在它是一个数字,零或一。 q 也一样 - 这是一个概率分布,现在是一个数字(概率)。
如果p 是一个向量,则每个单独的组件都被视为一个独立的二元分类。请参阅this answer,它概述了 tensorflow 中 softmax 和 sigmoid 函数之间的区别。所以p = [0, 0, 0, 1, 0]的定义并不意味着一个单热向量,而是5个不同的特征,其中4个关闭,1个打开。 q = [0.2, 0.2, 0.2, 0.2, 0.2] 的定义意味着 5 个特征中的每一个都以 20% 的概率开启。
这解释了在交叉熵之前使用sigmoid函数:它的目标是将logit压缩到[0, 1]区间。
上面的公式仍然适用于多个独立特征,而这正是tf.nn.sigmoid_cross_entropy_with_logits 计算的结果:
p = tf.placeholder(tf.float32, shape=[None, 5])
logit_q = tf.placeholder(tf.float32, shape=[None, 5])
q = tf.nn.sigmoid(logit_q)
feed_dict = {
p: [[0, 0, 0, 1, 0],
[1, 0, 0, 0, 0]],
logit_q: [[0.2, 0.2, 0.2, 0.2, 0.2],
[0.3, 0.3, 0.2, 0.1, 0.1]]
}
prob1 = -p * tf.log(q)
prob2 = p * -tf.log(q) + (1 - p) * -tf.log(1 - q)
prob3 = p * -tf.log(tf.sigmoid(logit_q)) + (1-p) * -tf.log(1-tf.sigmoid(logit_q))
prob4 = tf.nn.sigmoid_cross_entropy_with_logits(labels=p, logits=logit_q)
print(prob1.eval(feed_dict))
print(prob2.eval(feed_dict))
print(prob3.eval(feed_dict))
print(prob4.eval(feed_dict))
你应该看到最后三个张量是相等的,而prob1只是交叉熵的一部分,所以只有当p是1时它才包含正确的值:
[[ 0. 0. 0. 0.59813893 0. ]
[ 0.55435514 0. 0. 0. 0. ]]
[[ 0.79813886 0.79813886 0.79813886 0.59813887 0.79813886]
[ 0.5543552 0.85435522 0.79813886 0.74439669 0.74439669]]
[[ 0.7981388 0.7981388 0.7981388 0.59813893 0.7981388 ]
[ 0.55435514 0.85435534 0.7981388 0.74439663 0.74439663]]
[[ 0.7981388 0.7981388 0.7981388 0.59813893 0.7981388 ]
[ 0.55435514 0.85435534 0.7981388 0.74439663 0.74439663]]
现在应该清楚的是,在此设置中取 -p * tf.log(q) 和 axis=1 的总和是没有意义的,尽管它在多类情况下是一个有效的公式。