【问题标题】:Implementation of Focal loss for multi label classification多标签分类Focal loss的实现
【发布时间】:2019-12-29 07:16:10
【问题描述】:

尝试为多标签分类编写焦点损失

class FocalLoss(nn.Module):
    def __init__(self, gamma=2, alpha=0.25):
        self._gamma = gamma
        self._alpha = alpha

    def forward(self, y_true, y_pred):
        cross_entropy_loss = torch.nn.BCELoss(y_true, y_pred)
        p_t = ((y_true * y_pred) +
               ((1 - y_true) * (1 - y_pred)))
        modulating_factor = 1.0
        if self._gamma:
            modulating_factor = torch.pow(1.0 - p_t, self._gamma)
        alpha_weight_factor = 1.0
        if self._alpha is not None:
            alpha_weight_factor = (y_true * self._alpha +
                                   (1 - y_true) * (1 - self._alpha))
        focal_cross_entropy_loss = (modulating_factor * alpha_weight_factor *
                                    cross_entropy_loss)
        return focal_cross_entropy_loss.mean()

但是当我运行它时,我得到了

  File "train.py", line 82, in <module>
    loss = loss_fn(output, target)
  File "/home/bubbles/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 538, in __call__
    for hook in self._forward_pre_hooks.values():
  File "/home/bubbles/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 591, in __getattr__
    type(self).__name__, name))
AttributeError: 'FocalLoss' object has no attribute '_forward_pre_hooks'

任何建议都会非常有帮助,在此先感谢。

【问题讨论】:

    标签: deep-learning pytorch


    【解决方案1】:

    您不应该从 torch.nn.Module 继承,因为它是为具有可学习参数的模块(例如神经网络)而设计的。

    只需创建普通的仿函数或函数就可以了。

    顺便说一句。如果你继承它,你应该在你的__init__() 的某个地方调用super().__init__()

    【讨论】:

      猜你喜欢
      • 2020-12-19
      • 2020-10-26
      • 1970-01-01
      • 2021-02-14
      • 2019-12-24
      • 2019-04-01
      • 2019-07-10
      • 2016-02-13
      相关资源
      最近更新 更多