【问题标题】:Multi-class weighted loss for semantic image segmentation in keras/tensorflowkeras/tensorflow中语义图像分割的多类加权损失
【发布时间】:2020-04-18 14:35:24
【问题描述】:

给定批处理 RGB 图像作为输入,shape=(batch_size, width, height, 3)

一个多类目标表示为 one-hot,shape=(batch_size, width, height, n_classes)

还有一个模型(Unet、DeepLab)在最后一层激活了 softmax。

我正在寻找 kera/tensorflow 中的加权分类交叉熵损失函数。

fit_generator 中的class_weight 参数似乎不起作用,我在这里或https://github.com/keras-team/keras/issues/2115 中都没有找到答案。

def weighted_categorical_crossentropy(weights):
    # weights = [0.9,0.05,0.04,0.01]
    def wcce(y_true, y_pred):
        # y_true, y_pred shape is (batch_size, width, height, n_classes)
        loos = ?...
        return loss

    return wcce

【问题讨论】:

  • 多类目标是指考虑了超过 1 个可能的结果吗?
  • “结果”是什么意思? Multiclass=不同的像素值表示不同的类别。你可以有两个以上的课程。 (2 类=二元分类)
  • 多类分类是一种不同类型的分类问题,其中不止一个类是真实的,我对此感到困惑。

标签: tensorflow keras deep-learning semantic-segmentation


【解决方案1】:

我会回答我的问题:

def weighted_categorical_crossentropy(weights):
    # weights = [0.9,0.05,0.04,0.01]
    def wcce(y_true, y_pred):
        Kweights = K.constant(weights)
        if not K.is_tensor(y_pred): y_pred = K.constant(y_pred)
        y_true = K.cast(y_true, y_pred.dtype)
        return K.categorical_crossentropy(y_true, y_pred) * K.sum(y_true * Kweights, axis=-1)
    return wcce

用法:

loss = weighted_categorical_crossentropy(weights)
optimizer = keras.optimizers.Adam(lr=0.01)
model.compile(optimizer=optimizer, loss=loss)

【讨论】:

    【解决方案2】:

    我正在使用广义骰子损失。在我的情况下,它比加权分类交叉熵更好。我的实现是在 PyTorch 中,但是,它应该很容易翻译。

    class GeneralizedDiceLoss(nn.Module):
        def __init__(self):
            super(GeneralizedDiceLoss, self).__init__()
    
        def forward(self, inp, targ):
            inp = inp.contiguous().permute(0, 2, 3, 1)
            targ = targ.contiguous().permute(0, 2, 3, 1)
    
            w = torch.zeros((targ.shape[-1],))
            w = 1. / (torch.sum(targ, (0, 1, 2))**2 + 1e-9)
    
            numerator = targ * inp
            numerator = w * torch.sum(numerator, (0, 1, 2))
            numerator = torch.sum(numerator)
    
            denominator = targ + inp
            denominator = w * torch.sum(denominator, (0, 1, 2))
            denominator = torch.sum(denominator)
    
            dice = 2. * (numerator + 1e-9) / (denominator + 1e-9)
    
            return 1. - dice
    

    【讨论】:

      【解决方案3】:

      此问题可能类似于:Unbalanced data and weighted cross entropy,其答案已被接受。

      【讨论】:

      • 不,不是。我问的是像素分类。
      猜你喜欢
      • 2017-07-24
      • 2018-10-19
      • 2017-06-26
      • 2018-10-22
      • 2017-12-07
      • 2018-08-07
      • 2020-07-28
      • 2022-10-12
      • 1970-01-01
      相关资源
      最近更新 更多