【问题标题】:Multi label classification with unbalanced labels带有不平衡标签的多标签分类
【发布时间】:2021-04-19 10:21:57
【问题描述】:

我正在构建多标签分类网络。 我的 GT 是长度向量 512 [0,0,0,1,0,1,0,...,0,0,0,1] 大多数时候它们是zeroes,每个向量大约有5 ones,其余的都是零。

我正在考虑这样做:

使用sigmoid 激活输出层。

使用binary_crossentropy 作为损失函数。

但是我该如何解决不平衡问题呢? 网络可以学习预测always zeros,但学习损失分数仍然很低。

我怎样才能让它真正学会预测...

【问题讨论】:

  • @ivan 这是一个完全不同的问题。

标签: python tensorflow neural-network pytorch multilabel-classification


【解决方案1】:

你不能轻易上采样,因为这是一个多标签案例(我最初从帖子中遗漏了)。

你可以做的是给1更高的权重,像这样:

import torch


class BCEWithLogitsLossWeighted(torch.nn.Module):
    def __init__(self, weight, *args, **kwargs):
        super().__init__()
        # Notice none reduction
        self.bce = torch.nn.BCEWithLogitsLoss(*args, **kwargs, reduction="none")
        self.weight = weight

    def forward(self, logits, labels):
        loss = self.bce(logits, labels)
        binary_labels = labels.bool()
        loss[binary_labels] *= labels[binary_labels] * self.weight
        # Or any other reduction
        return torch.mean(loss)


loss = BCEWithLogitsLossWeighted(50)
logits = torch.randn(64, 512)
labels = torch.randint(0, 2, size=(64, 512)).float()

print(loss(logits, labels))

您也可以使用FocalLoss 专注于正面示例(某些库中应该有一些实现)。

编辑:

Focal Loss 也可以按照这些方式进行编码(功能形式,因为这就是我在 repo 中的内容,但你应该能够从中工作):

def binary_focal_loss(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    gamma: float,
    weight=None,
    pos_weight=None,
    reduction: typing.Callable[[torch.Tensor], torch.Tensor] = None,
) -> torch.Tensor:

    probabilities = (1 - torch.sigmoid(outputs)) ** gamma
    loss = probabilities * torch.nn.functional.binary_cross_entropy_with_logits(
        outputs,
        targets.float(),
        weight,
        reduction="none",
        pos_weight=pos_weight,
    )

    return reduction(loss)

【讨论】:

    猜你喜欢
    • 2020-02-01
    • 1970-01-01
    • 2017-11-01
    • 2013-12-26
    • 2020-10-23
    • 1970-01-01
    • 2017-08-26
    • 2018-06-11
    • 2021-02-26
    相关资源
    最近更新 更多