【问题标题】:Binary classification - BCELoss and model output size not corresponding二进制分类 - BCELoss 和模型输出大小不对应
【发布时间】:2021-08-09 08:51:12
【问题描述】:

我正在做一个二元分类,因此我使用了二元交叉熵损失:

criterion = torch.nn.BCELoss()

但是,我收到一个错误:

Using a target size (torch.Size([64, 1])) that is different to the input size (torch.Size([64, 2])) is deprecated. Please ensure they have the same size.

我的模型以:

结尾
    x = self.wave_block6(x)
    x = self.sigmoid(self.fc(x))
    return x.squeeze()

我尝试移除挤压,但无济于事。我的批量大小是 64。似乎我在这里做错了一些简单的事情。我的模型是否提供 1 个输出和预期 2 个输入的 BCE 损失?那我应该使用哪个损失?

【问题讨论】:

标签: pytorch


【解决方案1】:

二元交叉熵损失 (BCELoss) 用于二元分类任务。因此,如果 N 是您的批量大小,则模型输出的形状应为 [64, 1],而标签的形状必须为 [64]。因此,只需在第二维压缩输出并将其传递给损失函数 - 这是一个最小的工作示例

import torch
a = torch.randn((64, 1))
b = torch.randn((64))
loss = torch.nn.BCELoss()

b = torch.round(torch.sigmoid(b)) # just to create some labels
a = torch.sigmoid(a).squeeze(1)
l = loss(a, b)

更新 - 基于 cmets 中的对话,focal loss 可以定义如下 -

class focalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=3):
        super(focalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred_logits: torch.Tensor, target: torch.Tensor):
        batch_size = pred_logits.shape[0]
        pred = pred.view(batch_size, -1)
        target = target.view(batch_size, -1)
        pred = pred_logits.sigmoid()
        ce = F.binary_cross_entropy(pred_logits, target, reduction='none')
        alpha = target * self.alpha + (1. - target) * (1. - self.alpha)
        pt = torch.where(target == 1, pred, 1 - pred)
        return alpha * (1. - pt) ** self.gamma * ce

【讨论】:

  • 仍然遇到同样的错误。也许我的 colab 没有从我的本地驱动器刷新代码。我还注意到在训练中这被称为:loss = criteria(output, target.long())
  • 对于 BCELoss,target 和 input 都应该是 float 类型,所以一旦解决了这个错误,你稍后会得到这个错误,请注意,CrossEntropy 损失函数采用 long 类型的目标,所以也许原始代码是为 CrossEntropy loss 编写的
  • @dorien,回到手头的问题,你能告诉我你的标签形状(即你的目标)和你的模型输出形状吗?您在问题中提到它是 `64, 2` 但您使用的是 BCELoss,对于二进制分类问题,您只需要输出一个值,而不是 2
  • 我在原来的loss中使用的是FocalLoss,正在尝试适应BCELoss。让我跑去打印形状(抱歉,这需要一段时间,大数据集),尝试在它们本地运行。
  • 另外,我将添加一个可以直接使用的焦点损失定义
猜你喜欢
  • 2018-06-06
  • 1970-01-01
  • 2020-12-20
  • 2017-12-24
  • 1970-01-01
  • 2021-05-02
  • 1970-01-01
  • 2020-03-12
  • 2021-05-30
相关资源
最近更新 更多