【问题标题】:How to handle class imbalance in multi-label classification using pytorch如何使用pytorch处理多标签分类中的类不平衡
【发布时间】:2020-09-04 20:13:49
【问题描述】:

我们正在尝试在 pytorch 中使用 CNN 实现多标签分类。我们有 8 个标签和大约 260 张图像,使用 90/10 分割训练/验证集。

这些类别高度不平衡,最常见的类别出现在 140 多幅图像中。另一方面,最不频繁的类别出现在少于 5 张图像中。

我们最初尝试了 BCEWithLogitsLoss 函数,该函数导致模型为所有图像预测相同的标签。

然后我们实现了一个焦点损失方法来处理类不平衡,如下所示:

    import torch.nn as nn
    import torch

    class FocalLoss(nn.Module):
        def __init__(self, alpha=1, gamma=2):
            super(FocalLoss, self).__init__()
            self.alpha = alpha
            self.gamma = gamma

        def forward(self, outputs, targets):
            bce_criterion = nn.BCEWithLogitsLoss()
            bce_loss = bce_criterion(outputs, targets)
            pt = torch.exp(-bce_loss)
            focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss
            return focal_loss 

这导致模型为每个图像预测空集(无标签),因为它无法获得任何类别的大于 0.5 的置信度。

pytorch 中是否有方法可以帮助解决这种情况?

【问题讨论】:

  • 你试过在BCEWithLogitsLoss中设置pos_weight吗?
  • 5 张图片是一个非常小的样本量。收集更多数据。

标签: machine-learning pytorch multilabel-classification conv-neural-network


【解决方案1】:

基本上有三种处理方法。

  1. 丢弃更常见类中的数据

  2. 权重少数类损失值更重

  3. 过采样少数类

选项 1 是通过选择包含在数据集中的文件来实现的。

选项 2 使用 BCEWithLogitsLosspos_weight 参数实现

选项 3 是通过传递给您的 Dataloader 的自定义 Sampler 实现的

对于深度学习,过采样通常效果最好。

【讨论】:

    猜你喜欢
    • 2013-12-26
    • 2020-02-01
    • 2017-11-03
    • 1970-01-01
    • 2016-06-19
    • 2021-04-19
    • 2019-11-22
    • 2021-06-28
    • 1970-01-01
    相关资源
    最近更新 更多