【问题标题】:Understanding Cross Entropy Loss了解交叉熵损失
【发布时间】:2018-09-03 22:36:02
【问题描述】:

我看到很多关于 CEL 或二元交叉熵损失的解释,在基本事实是 0 或 1 的上下文中,然后你会得到如下函数:

def CrossEntropy(yHat, y):
    if yHat == 1:
      return -log(y)
    else:
      return -log(1 - y)

但是,当您的 yHat 不是离散的 0 或 1 时,我对 BCE 的工作方式感到困惑。例如,如果我想查看一个 MNIST 数字的重建损失,其中我的基本事实是 0

编辑:

抱歉,让我为我的困惑提供更多背景信息。在有关 VAE 的 PyTorch 教程中,他们使用 BCE 来计算重建损失,其中 yhat(据我所知,不是离散的)。见:

https://github.com/pytorch/examples/blob/master/vae/main.py

实现工作......但我不明白在这种情况下如何计算 BCE 损失。

【问题讨论】:

  • 对于带有图像的自动编码器,您可以将像素值标准化为[0, 1] 范围,然后使用 BCE 逐像素
  • 当然可以,但他们就是在这里做的吗?
  • 在您发布的代码中,该代码处理的不仅仅是 0 和 1。第一个 if 语句处理 1 的情况,但 else 语句处理所有其他值,不仅仅是0
  • 查看 pytorch.org/docs/master/…pytorch.org/docs/master/nn.html#torch.nn.BCEWithLogitsLoss 他们从数据中获取 sigmoid 函数,因此它被归一化为 [0, 1]
  • @stackoverflowuser2010 - 是的,但是如果它采用任何不是 0 或 1 的值,则此代码将无法正常工作(如计算正确的 CE 损失)。

标签: python machine-learning neural-network loss-function


【解决方案1】:

交叉熵测量任意两个概率分布之间的距离。在您所描述的(VAE)中,MNIST 图像像素被解释为像素“开/关”的概率。在这种情况下,您的目标概率分布根本不是狄拉克分布(0 或 1),而是可以具有不同的值。见the cross entropy definition on Wikipedia

以上述为参考,假设您的模型输出一个 0.7 像素的重建。这实质上是说您的模型估计 p(pixel=1) = 0.7,因此 p(pixel=0) = 0.3。
如果目标像素仅为 0 或 1,则如果真实像素为 0,则该像素的交叉熵将为 -log(0.3);如果真实像素为 0,则为 -log(0.7)(smaller 值)真正的像素是 1。
如果真实像素为 1,则完整公式为 -(0*log(0.3) + 1*log(0.7)),否则为 -(1*log(0.3) + 1*log(0.7))。

假设您的目标像素实际上是 0.6!这实质上是说像素有 0.6 的概率开启,0.4 的概率关闭。
这只是将交叉熵计算更改为 -(0.4*log(0.3) + 0.6*log(0.7))。

最后,您可以简单地对图像上的这些每像素交叉熵进行平均/求和。

【讨论】:

    【解决方案2】:

    您通常不应将非二进制类集编码为 0 和 1 之间的值。在 MNIST 的情况下,如果您要标记每个数字 0、0.1、0.2 等,这意味着图像2 的图像比 5 的图像更类似于 0 的图像,这不一定是正确的。

    要做的一件好事是将标签“一次热编码”,作为 0 的 10 元素数组。然后,设置数字图像对应的索引为1。

    如上所述,您将使用常规的交叉熵损失函数。然后,您的模型应该为每个样本输出一个条件概率向量,对应于每个可能的类。可能使用了softmax函数。

    【讨论】:

      【解决方案3】:

      交叉熵损失仅用于分类问题:即您的目标 (yHat) 是离散的。如果您有回归问题,则诸如均方误差 (MSE) 损失之类的东西会更合适。您可以找到 PyTorch 库的各种损失及其实现here

      在 MNIST 数据集的情况下,您实际上有一个多类分类问题(您试图从 10 个可能的数字中预测正确的数字),因此二进制交叉熵损失不适合,您应该而是一般的交叉熵损失。

      无论如何,调查的第一步应该是确定您的问题是“分类”还是“回归”。适用于一个问题的损失函数通常不适用于另一个问题。

      编辑:您可以在 TensorFlow 网站上的 "MNIST for ML Beginners" tutorial 找到有关 MNIST 问题上下文中交叉熵损失的更详细说明。

      【讨论】:

      • 嗯,我主要是在 VAE 的上下文中感到困惑。例如pytorch官方例子中,BCE用于重建损失:github.com/pytorch/examples/blob/master/vae/main.py
      • 对抗性问题不同于传统的 MNIST 分类问题。在对抗性问题中,第二个网络尝试输出第一个网络生成的图像是真实数据的概率。这是一个二元问题:真实与虚假。在传统的分类问题中,(单个)网络试图输出图像对应于任何数字的概率。这是一个多类问题:0 vs 1 vs 2 vs etc.
      • 无论如何,对于真假问题,BCE 是合适的。您编写的函数应该可以正常工作,除了yHat 如果图像是假的,则为 0,如果图像是真实的,则为 1; y 是它是真实的概率(由第二个网络生成)。
      • 我想你误会了,我指的是变分自动编码器,而不是 GAN。
      • 另外,我很想看看源代码,但我在挖掘它时遇到了一些麻烦。如果我没记错的话,nn 模块调用了 F. 模块,而 BCE 又调用了其他东西(C 中的文件?),我在实际挖掘它时遇到了麻烦......
      猜你喜欢
      • 2020-10-09
      • 2021-08-25
      • 2016-08-01
      • 2019-06-19
      • 2017-03-14
      • 2021-11-25
      • 2018-04-14
      • 2020-06-15
      • 1970-01-01
      相关资源
      最近更新 更多