【发布时间】:2020-09-04 20:13:49
【问题描述】:
我们正在尝试在 pytorch 中使用 CNN 实现多标签分类。我们有 8 个标签和大约 260 张图像,使用 90/10 分割训练/验证集。
这些类别高度不平衡,最常见的类别出现在 140 多幅图像中。另一方面,最不频繁的类别出现在少于 5 张图像中。
我们最初尝试了 BCEWithLogitsLoss 函数,该函数导致模型为所有图像预测相同的标签。
然后我们实现了一个焦点损失方法来处理类不平衡,如下所示:
import torch.nn as nn
import torch
class FocalLoss(nn.Module):
def __init__(self, alpha=1, gamma=2):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
def forward(self, outputs, targets):
bce_criterion = nn.BCEWithLogitsLoss()
bce_loss = bce_criterion(outputs, targets)
pt = torch.exp(-bce_loss)
focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss
return focal_loss
这导致模型为每个图像预测空集(无标签),因为它无法获得任何类别的大于 0.5 的置信度。
pytorch 中是否有方法可以帮助解决这种情况?
【问题讨论】:
-
你试过在
BCEWithLogitsLoss中设置pos_weight吗? -
5 张图片是一个非常小的样本量。收集更多数据。
标签: machine-learning pytorch multilabel-classification conv-neural-network