【问题标题】:Multilabel classification with class imbalance in PytorchPytorch中具有类不平衡的多标签分类
【发布时间】:2020-02-01 00:21:06
【问题描述】:

我有一个多标签分类问题,我正在尝试用 Pytorch 中的 CNN 解决这个问题。我有 80,000 个训练示例和 7900 个课程;每个示例可以同时属于多个类,每个示例的平均类数为 130。

问题是我的数据集非常不平衡。对于某些课程,我只有大约 900 个示例,大约 1%。对于“过度代表”的课程,我有大约 12000 个示例(15%)。当我训练模型时,我使用来自pytorch 的 BCEWithLogitsLoss 和一个正权重参数。我计算权重的方式与文档中描述的相同:负例数除以正例数。

因此,我的模型几乎高估了每个类别……无论是小类还是大类,我得到的预测几乎是真实标签的两倍。而我的 AUPRC 只有 0.18。尽管它比完全不加权要好得多,因为在这种情况下,模型将所有内容都预测为零。

所以我的问题是,如何提高性能?还有什么我可以做的吗?我尝试了不同的批量采样技术(对少数类进行过采样),但它们似乎不起作用。

【问题讨论】:

    标签: pytorch multilabel-classification imbalanced-data


    【解决方案1】:

    我会建议其中一种策略

    焦点损失


    Tsung-Yi Lin、Priya Goyal、Ross Girshick、Kaiming He 和 Piotr Dollar介绍了一种非常有趣的方法,通过调整损失函数来处理不平衡的训练数据Focal Loss for Dense Object Detection (ICCV 2017)。
    他们建议修改二元交叉熵损失,以减少易于分类示例的损失和梯度,同时“将努力”集中在模型出现严重错误的示例上。

    硬负挖掘

    另一种流行的方法是做“硬负挖掘”;也就是说,只为部分训练示例传播梯度 - “硬”示例。
    参见,例如:
    Abhinav Shrivastava、Abhinav Gupta 和 Ross Girshick Training Region-based Object Detectors with Online Hard Example Mining (CVPR 2016)

    【讨论】:

    • Focal loss 在这种情况下可能不是一个好的选择,有 7900 个类。有太多超参数需要微调。
    • @zihaozhihao 确实很棘手。但我会尝试对所有类使用相同的 gamma。
    【解决方案2】:

    @Shai 提供了两种在深度学习时代发展起来的策略。我想为您提供一些额外的传统机器学习选项:过采样欠采样

    它们的主要思想是在开始训练之前通过采样产生更平衡的数据集。请注意,您可能会面临一些问题,例如丢失数据多样性(欠采样)和过度拟合训练数据(过采样),但这可能是一个很好的起点。

    请参阅wiki link 了解更多信息。

    【讨论】:

      猜你喜欢
      • 2020-09-04
      • 2021-04-19
      • 2020-10-23
      • 2013-12-26
      • 1970-01-01
      • 2017-11-01
      • 2019-08-05
      • 2021-02-26
      • 2017-05-26
      相关资源
      最近更新 更多