【问题标题】:Pytorch Categorical Cross Entropy loss function behaviourPytorch 分类交叉熵损失函数行为
【发布时间】:2020-03-14 08:17:09
【问题描述】:

我对 Pytorch 的分类交叉熵损失的计算有疑问。 我编写了这个简单的代码 sn-p 并且因为我使用输出张量的 argmax 作为目标,所以我无法理解为什么损失仍然很高。

import torch
import torch.nn as nn
ce_loss = nn.CrossEntropyLoss()
output = torch.randn(3, 5, requires_grad=True)
targets = torch.argmax(output, dim=1)
loss = ce_loss(outputs, targets)
print(loss)

感谢您帮助理解它。 最好的祝福 杰罗姆

【问题讨论】:

  • 高是什么意思?查看我的答案,了解如何计算损失。

标签: pytorch


【解决方案1】:

因此,这是来自您的代码的示例数据,其中 outputlabelloss 具有以下值

outputs =  tensor([[ 0.5968, -0.8249,  1.5018,  2.7888, -0.6125],
                   [-1.1534, -0.4921,  1.0688,  0.2241, -0.0257],
                   [ 0.3747,  0.8957,  0.0816,  0.0745,  0.2695]], requires_grad=True)requires_grad=True)

labels = tensor([3, 2, 1])
loss = tensor(0.7354, grad_fn=<NllLossBackward>)

让我们检查一下这些值,

如果你计算你的 logits (outputs) 的 softmax 输出,使用类似 torch.softmax(outputs,axis=1) 的东西你会得到

probs = tensor([[0.0771, 0.0186, 0.1907, 0.6906, 0.0230],
                [0.0520, 0.1008, 0.4801, 0.2063, 0.1607],
                [0.1972, 0.3321, 0.1471, 0.1461, 0.1775]], grad_fn=<SoftmaxBackward>)

所以这些将是您的预测概率。

现在交叉熵损失只不过是softmaxnegative log likelihood loss. 的组合因此,您的损失可以简单地使用

loss = (torch.log(1/probs[0,3]) +  torch.log(1/probs[1,2]) + torch.log(1/probs[2,1])) / 3

,它是真实标签概率的负对数的平均值。上面的等式计算结果为0.7354,相当于从nn.CrossEntropyLoss 模块返回的值。

【讨论】:

  • 感谢您的解释,但首先我迷失了 argmax 返回:我在想 argmax 将返回最大值的索引。因此,当我采用输出张量时,我希望 argmqx 返回 [1, 3, 4]...
  • 并且损失不会是0,除非您以1的概率预测正确的标签。如果您查看标签为3的第一个实例,我们预测该类具有0.6906 的可移植性,这意味着我们不能 100% 确定,因此我们会遭受一些损失。
  • 啊,好吧,我在想它使用 0.5 的阈值来表示要预测哪个类。现在它是有道理的,它使用没有阈值的概率!非常感谢。 :)
  • 如果您的问题得到了回答,您能接受这个答案吗?谢谢。
  • @Jérôme MASSOT 如果它在损失函数中使用阈值,那么它将无法区分。从某种意义上说,每个损失函数都需要“模糊”,即使是与完美的微小偏差也会导致损失增加。
猜你喜欢
  • 2021-08-25
  • 2018-04-14
  • 2017-03-14
  • 2019-11-02
  • 2017-12-08
  • 2021-01-21
  • 2022-01-09
  • 2020-08-13
  • 2020-12-18
相关资源
最近更新 更多