【问题标题】:Getting nan as loss value获取 nan 作为损失值
【发布时间】:2021-03-26 08:04:52
【问题描述】:

我使用paper 在 Pytorch 中实现了焦点损失。并遇到了损失问题 - 将 nan 作为损失函数值。

这是焦点损失的实现:

def focal_loss(y_real, y_pred, gamma = 2):
    y_pred = torch.sigmoid(y_pred)
    return -torch.sum((1 - y_pred)**gamma * y_real * torch.log(y_pred) +
                       y_pred**gamma * (1 - y_real) * torch.log(1 - y_pred))

我认为训练循环和我的 SegNet 可以正常工作,因为我已经用 dice 和 bce 损失对它们进行了测试。

我认为错误发生在反向传播中。为什么会这样?也许我的实现是错误的?

【问题讨论】:

    标签: deep-learning pytorch loss-function


    【解决方案1】:

    此版本正在运行:

    def focal_loss(y_real, y_pred, eps = 1e-8, gamma = 0):
        probabilities = torch.clamp(torch.sigmoid(y_pred), min=eps, max=1-eps)
        return torch.mean((1 - probabilities)**gamma * 
               (y_pred - y_real * y_pred + torch.log(1 + torch.exp(-y_pred))))
    

    【讨论】:

      【解决方案2】:

      这很可能是由于尝试计算 log(0)。

      我建议像这样更改代码:

      EPS = 1e-9
      def focal_loss(y_real, y_pred, gamma = 2):
          y_pred = torch.sigmoid(y_pred)
          y_pred = torch.clamp(y_pred, EPS, 1. - EPS)
          return -torch.sum((1 - y_pred)**gamma * y_real * torch.log(y_pred) +
                             y_pred**gamma * (1 - y_real) * torch.log(1 - y_pred))
      

      【讨论】:

      • 谢谢,我已经尝试过这个添加,但我得到了同样的错误
      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2020-03-12
      • 2021-09-06
      • 1970-01-01
      • 2021-03-26
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多