【问题标题】:Loss function for sequences (in Tensorflow 2.0)序列的损失函数(在 Tensorflow 2.0 中)
【发布时间】:2020-11-05 01:17:07
【问题描述】:

我正在研究从英语到德语的句子翻译问题。 所以最终的输出是一个德语序列,我需要检查我的预测有多好。

我在 tensorflow 教程中发现了以下损失函数:

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

def loss_function(real, pred):
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)

    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask

    return tf.reduce_mean(loss_)

但我不知道这个函数是做什么的。我知道(也许我错了)我们不能以直接的方式将 SparseCategoricalCrossentropy 用于序列,我们必须进行某种操作。 但是例如在上面的代码中,我看到 SparseCategoricalCrossentropy 在序列输出中以直接的方式使用。为什么?

mask 变量有什么作用? 能解释一下代码吗?

编辑:教程-https://www.tensorflow.org/tutorials/text/nmt_with_attention

【问题讨论】:

    标签: tensorflow deep-learning loss-function


    【解决方案1】:

    mask 中的mask = tf.math.logical_not(tf.math.equal(real, 0)) 负责处理PADDING

    因此,在您的批次中,您将有不同长度的句子,并且您使用0 填充以使所有句子的长度相等(想想I have an apple v/s It's a good day to play football in the sun

    但是,在损失计算中包含 0 填充部分是没有意义的 - 因此,它首先查看有 0 的索引,然后使用乘法使其损失贡献为 0。

    【讨论】:

      猜你喜欢
      • 2019-12-16
      • 2020-07-28
      • 2020-04-04
      • 1970-01-01
      • 2020-03-16
      • 2019-01-20
      • 2021-12-08
      • 2019-02-03
      • 1970-01-01
      相关资源
      最近更新 更多