【问题标题】:How to implement pixelwise weighting of loss function in Keras?如何在 Keras 中实现损失函数的像素加权?
【发布时间】:2019-05-27 03:07:48
【问题描述】:

我正在使用 Keras 2.2.4,我正在尝试实现像素分类的损失函数,如here 中所述,但我遇到了这里提出的一些困难。我正在做 3D 分割,因此我的目标向量是(b_size, width_x, width_y, width_z, nb_classes)。我实现了以下损失函数,其中权重图与目标和预测向量的形状相同:

def dice_xent_loss(y_true, y_pred, weight_map):

    """Adaptation of https://arxiv.org/pdf/1809.10486.pdf for multilabel 
    classification with overlapping pixels between classes. Dec 2018.
    """
    loss_dice = weighted_dice(y_true, y_pred, weight_map)
    loss_xent = weighted_binary_crossentropy(y_true, y_pred, weight_map)

    return loss_dice + loss_xent

def weighted_binary_crossentropy(y_true, y_pred, weight_map):
    return tf.reduce_mean((K.binary_crossentropy(y_true, 
                                                 y_pred)*weight_map)) / (tf.reduce_sum(weight_map) + K.epsilon())

def weighted_dice(y_true, y_pred, weight_map):

    if weight_map is None:
        raise ValueError("Weight map cannot be None")
    if y_true.shape != weight_map.shape:
        raise ValueError("Weight map must be the same size as target vector")

    dice_numerator = 2.0 * K.sum(y_pred * y_true * weight_map, axis=[1,2,3])
    dice_denominator = K.sum(weight_map * y_true, axis=[1,2,3]) + \
                                                             K.sum(y_pred * weight_map, axis=[1,2,3])
    loss_dice = (dice_numerator) / (dice_denominator + K.epsilon())
    h1=tf.square(tf.minimum(0.1,loss_dice)*10-1)
    h2=tf.square(tf.minimum(0.01,loss_dice)*100-1)
    return 1.0 - tf.reduce_mean(loss_dice) + \
            tf.reduce_mean(h1)*10 + \
            tf.reduce_mean(h2)*10

我正在按照建议使用sample_weights=temporal 编译模型,并将权重作为sample_weight=weights 传递给model.fit。我仍然收到以下错误:

File "overfit_one_case.py", line 153, in <module>
    main()
File "overfit_one_case.py", line 81, in main
   sample_weight_mode="temporal")
 File "/home/igt/anaconda2/envs/niftynet/lib/python2.7/site-packages/keras/engine/training.py", line 342, in compile
sample_weight, mask)
File "/home/igt/anaconda2/envs/niftynet/lib/python2.7/site-packages/keras/engine/training_utils.py", line 404, in weighted
score_array = fn(y_true, y_pred)
TypeError: dice_xent_loss() takes exactly 3 arguments (2 given)

training_utils.py Keras 调用我的自定义损失,没有任何权重。关于如何解决这个问题的任何想法?我的另一个限制是我正在尝试对这个特定模型进行迁移学习。因此,我无法按照建议的hereweight_map 添加到Input 层。

【问题讨论】:

    标签: tensorflow keras deep-learning


    【解决方案1】:

    样本权重是样本的权重,而不是像素的权重。

    除了y_truey_pred,Keras 的损失从不接受任何其他参数。所有 keras 加权都是自动的。

    对于自定义权重,您需要自己实现。您可以将这些损失函数包装在一个带权重的函数中,如下所示:

    def weighted_dice_xent_loss(weight_map):
    
        def dice_xent_loss(y_true, y_pred):
            #code...    
            return loss_dice + loss_xent
        return dice_xent_loss
    
    def weighted_binary_crossentropy(weight_map):
        def inner_binary_crossentropy(y_true, y_pred):
            return tf.reduce_mean(
               (K.binary_crossentropy(y_true, y_pred)*weight_map)) / (
                tf.reduce_sum(weight_map) + K.epsilon())
            return inner_binnary_crossentropy
    
    def weighted_dice(weight_map):
        def dice(y_true, y_pred):
    
        #code....
            return 1.0 - tf.reduce_mean(loss_dice) + \
                tf.reduce_mean(h1)*10 + \
                tf.reduce_mean(h2)*10
       return dice
    

    例如,将它们用作loss=weighted_dice_xent_loss(weight_map)


    使用样本权重的丑陋解决方法。

    如果您的权重对于每个样本都是唯一的,那么您必须将每个像素变成一个样本(这很不寻常)。

    使用您的数据:

    • 扁平化数据的第一个维度,例如 (b_size * width_x * width_y * width_z, nb_channels)
    • 以同样的方式展平您的体重矩阵。
    • 以同样的方式展平你的真实输出

    使用您的模型:

    • 创建兼容的`inputs = Input((nb_channels,))
    • Lambda 层中使用K.reshape 重塑以恢复原始尺寸:K.reshape(x, (-1, width_x, width_y, width_z, nb_classes))
    • 照常制作模型的其余部分
    • 使用K.reshape(x, (-1, nb_classes)) 重塑Lambda 层中的输出

    你的损失:

    • 计算每个像素的损失,不要对像素求和。
    • Keras 权重将在您计算损失后求和(因此与骰子不兼容)

    【讨论】:

      猜你喜欢
      • 2020-07-28
      • 2019-08-08
      • 2020-01-27
      • 2019-08-31
      • 2017-07-24
      • 2018-10-19
      • 1970-01-01
      • 1970-01-01
      • 2018-02-24
      相关资源
      最近更新 更多