【发布时间】:2021-03-26 08:04:52
【问题描述】:
我使用paper 在 Pytorch 中实现了焦点损失。并遇到了损失问题 - 将 nan 作为损失函数值。
这是焦点损失的实现:
def focal_loss(y_real, y_pred, gamma = 2):
y_pred = torch.sigmoid(y_pred)
return -torch.sum((1 - y_pred)**gamma * y_real * torch.log(y_pred) +
y_pred**gamma * (1 - y_real) * torch.log(1 - y_pred))
我认为训练循环和我的 SegNet 可以正常工作,因为我已经用 dice 和 bce 损失对它们进行了测试。
我认为错误发生在反向传播中。为什么会这样?也许我的实现是错误的?
【问题讨论】:
标签: deep-learning pytorch loss-function