【发布时间】:2021-03-15 18:55:02
【问题描述】:
我正在将 keras 模型移植到 torch,但在 softmax 层之后,我无法复制 keras/tensorflow 的 'categorical_crossentropy' 的确切行为。我有一些解决这个问题的方法,所以我只想了解在计算分类交叉熵时 tensorflow 计算出的内容。
作为一个玩具问题,我设置了标签和预测向量
>>> import tensorflow as tf
>>> from tensorflow.keras import backend as K
>>> import numpy as np
>>> true = np.array([[0.0, 1.0], [1.0, 0.0]])
>>> pred = np.array([[0.0, 1.0], [0.0, 1.0]])
并计算分类交叉熵:
>>> loss = tf.keras.losses.CategoricalCrossentropy()
>>> print(loss(pred, true).eval(session=K.get_session()))
8.05904769897461
这与分析结果不同
>>> loss_analytical = -1*K.sum(true*K.log(pred))/pred.shape[0]
>>> print(loss_analytical.eval(session=K.get_session()))
nan
我深入研究了 keras/tf 的交叉熵的源代码(参见 Softmax Cross Entropy implementation in Tensorflow Github Source Code),并在 https://github.com/tensorflow/tensorflow/blob/c903b4607821a03c36c17b0befa2535c7dd0e066/tensorflow/compiler/tf2xla/kernels/softmax_op.cc 第 116 行找到了 c 函数。在该函数中,有一条注释:
// sum(-labels *
// ((logits - max_logits) - log(sum(exp(logits - max_logits)))))
// along classes
// (The subtraction broadcasts along the batch dimension.)
我尝试了实现它:
>>> max_logits = K.max(pred, axis=0)
>>> max_logits = max_logits
>>> xent = K.sum(-true * ((pred - max_logits) - K.log(K.sum(K.exp(pred - max_logits)))))/pred.shape[0]
>>> print(xent.eval(session=K.get_session()))
1.3862943611198906
我还尝试打印xent.eval(session=K.get_session()) 的跟踪,但跟踪的长度约为 95000 行。所以它引出了一个问题:在计算'categorical_crossentropy' 时,keras/tf 到底在做什么?它不返回nan 是有道理的,这会导致训练问题,但是 8 是从哪里来的呢?
【问题讨论】:
标签: python tensorflow machine-learning keras