【Question】: Keras U-Net weighted loss implementation
【Posted】: 2020-01-27 17:57:13
【Description】:

I am trying to separate touching objects as shown in the U-Net paper (here). For this, one can generate weight maps that are used in a pixel-wise loss. The following code describes the network I use, taken from this blog post.

import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras import layers as L
from tensorflow.keras.models import Model

x_train_val = ...  # array of images, shape (imgs, 256, 256, 3)
y_train_val = ...  # array of masks, shape (imgs, 256, 256, 1)
y_weights = ...    # array of weight maps, shape (imgs, 256, 256, 1), per the blog post
# visual inspection confirms the correct calculation of these maps
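For reference, the weight maps described above follow Eq. 2 of the U-Net paper, w(x) = w_c(x) + w0 * exp(-(d1(x) + d2(x))^2 / (2 * sigma^2)), where d1 and d2 are the distances to the nearest and second-nearest object. A minimal NumPy/SciPy sketch of that computation (omitting the class-balance term w_c, using the paper's defaults w0=10, sigma=5; not the blog post's exact code):

```python
import numpy as np
from scipy import ndimage

def unet_weight_map(mask, w0=10.0, sigma=5.0):
    """Border weight map from the U-Net paper (Eq. 2, w_c omitted).

    `mask` is a 2-D binary array; background pixels between close objects
    receive high weights, object pixels receive weight 1.
    """
    labels, n = ndimage.label(mask)
    if n < 2:  # fewer than two objects: no borders to separate
        return np.ones_like(mask, dtype=np.float32)
    # distance from every pixel to each labeled object
    dists = np.stack([ndimage.distance_transform_edt(labels != i)
                      for i in range(1, n + 1)], axis=-1)
    dists.sort(axis=-1)
    d1, d2 = dists[..., 0], dists[..., 1]  # nearest and second-nearest object
    w = w0 * np.exp(-((d1 + d2) ** 2) / (2 * sigma ** 2))
    # apply the border term on background only; objects keep weight 1
    return (1.0 + w * (mask == 0)).astype(np.float32)
```

Background pixels in the gap between two objects get the largest weights, which is exactly what pushes the network to learn separating borders.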

# Blog post's loss function
def my_loss(target, output):
    return - tf.reduce_sum(target * output,
                           len(output.get_shape()) - 1)

# Standard Unet model from blog post
_epsilon = tf.convert_to_tensor(K.epsilon(), np.float32)

def make_weighted_loss_unet(input_shape, n_classes):
    ip = L.Input(shape=input_shape)
    weight_ip = L.Input(shape=input_shape[:2] + (n_classes,))

    conv1 = L.Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(ip)
    conv1 = L.Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv1)
    conv1 = L.Dropout(0.1)(conv1)
    mpool1 = L.MaxPool2D()(conv1)

    conv2 = L.Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(mpool1)
    conv2 = L.Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
    conv2 = L.Dropout(0.2)(conv2)
    mpool2 = L.MaxPool2D()(conv2)

    conv3 = L.Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(mpool2)
    conv3 = L.Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
    conv3 = L.Dropout(0.3)(conv3)
    mpool3 = L.MaxPool2D()(conv3)

    conv4 = L.Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(mpool3)
    conv4 = L.Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
    conv4 = L.Dropout(0.4)(conv4)
    mpool4 = L.MaxPool2D()(conv4)

    conv5 = L.Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(mpool4)
    conv5 = L.Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
    conv5 = L.Dropout(0.5)(conv5)

    up6 = L.Conv2DTranspose(512, 2, strides=2, kernel_initializer='he_normal', padding='same')(conv5)
    conv6 = L.Concatenate()([up6, conv4])
    conv6 = L.Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)
    conv6 = L.Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)
    conv6 = L.Dropout(0.4)(conv6)

    up7 = L.Conv2DTranspose(256, 2, strides=2, kernel_initializer='he_normal', padding='same')(conv6)
    conv7 = L.Concatenate()([up7, conv3])
    conv7 = L.Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv7)
    conv7 = L.Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv7)
    conv7 = L.Dropout(0.3)(conv7)

    up8 = L.Conv2DTranspose(128, 2, strides=2, kernel_initializer='he_normal', padding='same')(conv7)
    conv8 = L.Concatenate()([up8, conv2])
    conv8 = L.Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv8)
    conv8 = L.Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv8)
    conv8 = L.Dropout(0.2)(conv8)

    up9 = L.Conv2DTranspose(64, 2, strides=2, kernel_initializer='he_normal', padding='same')(conv8)
    conv9 = L.Concatenate()([up9, conv1])
    conv9 = L.Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
    conv9 = L.Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
    conv9 = L.Dropout(0.1)(conv9)

    c10 = L.Conv2D(n_classes, 1, activation='softmax', kernel_initializer='he_normal')(conv9)

    # Mimic crossentropy loss
    c11 = L.Lambda(lambda x: x / tf.reduce_sum(x, len(x.get_shape()) - 1, True))(c10)
    c11 = L.Lambda(lambda x: tf.clip_by_value(x, _epsilon, 1. - _epsilon))(c11)
    c11 = L.Lambda(lambda x: K.log(x))(c11)
    weighted_sm = L.multiply([c11, weight_ip])

    model = Model(inputs=[ip, weight_ip], outputs=[weighted_sm])
    return model

I then compile and fit the model as follows:

model = make_weighted_loss_unet((256, 256, 3), 1) # shape of input, number of classes
model.compile(optimizer='adam', loss=my_loss, metrics=['acc'])
model.fit([x_train_val, y_weights], y_train_val, validation_split=0.1, epochs=1)

The model then trains as usual. However, the loss does not seem to improve much. Furthermore, when I try to predict on new images, I obviously have no weight maps (because they are computed on the labeled masks). I tried using empty/zero arrays shaped like the weight maps, but that only yields blank/zero predictions. I also tried different metrics and more standard losses, without any success.

Has anyone faced the same problem implementing this kind of weighted loss, or does anyone have an alternative? Thanks in advance. 烤羊肉

【Question discussion】:

    Tags: python keras deep-learning image-segmentation


    【Solution 1】:

    A simpler way to write a custom loss with pixel weights

    In your code, the loss is scattered between the my_loss and make_weighted_loss_unet functions. You can add the targets as an input and structure the code better with model.add_loss:

    def make_weighted_loss_unet(input_shape, n_classes):
        ip = L.Input(shape=input_shape)
        weight_ip = L.Input(shape=input_shape[:2] + (n_classes,))
        targets   = L.Input(shape=input_shape[:2] + (n_classes,))
        # .... rest of your model definition code ...
    
        c10 = L.Conv2D(n_classes, 1, activation='softmax', kernel_initializer='he_normal')(conv9)
        model = Model(inputs=[ip, weight_ip, targets], outputs=[c10])
        model.add_loss(pixel_weighted_cross_entropy(weight_ip, targets, c10))
        return model  # NO NEED to specify loss in model.compile
    
    def pixel_weighted_cross_entropy(weights, targets, predictions):
        loss_val = keras.losses.categorical_crossentropy(targets, predictions)
        weighted_loss_val = weights[..., 0] * loss_val  # loss_val has shape (batch, H, W)
        return K.mean(weighted_loss_val)
    
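Independently of Keras, what this loss computes is just the per-pixel categorical cross-entropy scaled by the weight map before averaging. A small NumPy check (shapes and values below are made up purely for illustration):

```python
import numpy as np

def pixel_weighted_ce(weights, targets, predictions, eps=1e-7):
    # per-pixel cross-entropy over the class axis, then weighted and averaged,
    # mirroring pixel_weighted_cross_entropy above
    ce = -np.sum(targets * np.log(np.clip(predictions, eps, 1.0)), axis=-1)
    return float(np.mean(weights * ce))

# one image, two pixels, two classes (one-hot targets)
targets = np.array([[[1.0, 0.0], [0.0, 1.0]]])
preds   = np.array([[[0.9, 0.1], [0.2, 0.8]]])

uniform  = pixel_weighted_ce(np.ones((1, 2)), targets, preds)
boundary = pixel_weighted_ce(np.array([[1.0, 10.0]]), targets, preds)
# boundary > uniform: errors on the up-weighted pixel dominate the loss
```

With uniform weights this reduces to ordinary mean cross-entropy; border weights simply amplify the contribution of the pixels between touching objects.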

    If you don't want to refactor your code to the approach above, the next section shows how you can still run inference without problems.

    How to run your model at inference time

    Option 1: use another Model object for inference

    You can create one Model for training and another for inference. They are mostly the same, except that the inference Model does not take weight_ip and returns the earlier output c10.

    Here is example code that adds a parameter is_training=True to decide which Model to return:

    def make_weighted_loss_unet(input_shape, n_classes, is_training=True):
        ip = L.Input(shape=input_shape)
    
        conv1 = L.Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(ip)
        # .... rest of your model definition code ...
        c10 = L.Conv2D(n_classes, 1, activation='softmax', kernel_initializer='he_normal')(conv9)
    
        if is_training:
            # Mimic crossentropy loss
            c11 = L.Lambda(lambda x: x / tf.reduce_sum(x, len(x.get_shape()) - 1, True))(c10)
            c11 = L.Lambda(lambda x: tf.clip_by_value(x, _epsilon, 1. - _epsilon))(c11)
            c11 = L.Lambda(lambda x: K.log(x))(c11)
            weight_ip = L.Input(shape=input_shape[:2] + (n_classes,))
            weighted_sm = L.multiply([c11, weight_ip])
            return Model(inputs=[ip, weight_ip], outputs=[weighted_sm])
        else:
            return Model(inputs=[ip], outputs=[c10]) 
    
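One caveat with this approach: calling make_weighted_loss_unet twice builds two independent sets of weights, so after training you must copy the trained weights into the inference model (e.g. via get_weights/set_weights, or save_weights/load_weights). A toy sketch of the pattern, with the U-Net body replaced by a single hypothetical conv layer so it stays self-contained:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers as L
from tensorflow.keras.models import Model

def build(is_training=True):
    # Toy stand-in for make_weighted_loss_unet: trainable layers are created
    # in the same order in both variants, so get_weights/set_weights line up.
    ip = L.Input(shape=(8, 8, 3))
    body = L.Conv2D(4, 3, padding='same', activation='relu')(ip)  # "U-Net body"
    c10 = L.Conv2D(2, 1, activation='softmax')(body)
    if is_training:
        weight_ip = L.Input(shape=(8, 8, 2))
        weighted = L.multiply([L.Lambda(tf.math.log)(c10), weight_ip])
        return Model(inputs=[ip, weight_ip], outputs=[weighted])
    return Model(inputs=[ip], outputs=[c10])

train_model = build(is_training=True)
# ... train_model.compile(...) and train_model.fit(...) would go here ...

infer_model = build(is_training=False)
infer_model.set_weights(train_model.get_weights())  # copy the trained kernels over
pred = infer_model.predict(np.random.rand(1, 8, 8, 3).astype('float32'), verbose=0)
```

Alternatively, save_weights on the training model and load_weights(..., by_name=True) on the inference model works when the shared layers are given matching names.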

    Option 2: use K.function

    If you don't want to touch your model definition method (make_weighted_loss_unet) and want to achieve the same result from outside, you can use a function that extracts the subgraph relevant for inference.

    In your inference function:

    from keras import backend as K
    
    model = make_weighted_loss_unet(input_shape, n_classes)
    inference_function = K.function([model.get_layer("input_layer").input], 
                                    [model.get_layer("output_softmax_layer").output])
    predicted_heatmap = inference_function([new_image])[0]  # K.function takes/returns lists
    

    Note that you will have to pass name= to your ip and c10 layers so they can be retrieved via model.get_layer(name):

    ip = L.Input(shape=input_shape, name="input_layer")
    

    c10 = L.Conv2D(n_classes, 1, activation='softmax', kernel_initializer='he_normal', name="output_softmax_layer")(conv9)
    
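The same subgraph extraction can also be done by wrapping the intermediate tensors in a new Model, which avoids backend-specific details. A toy sketch with the layer names suggested above (the full U-Net body is elided and replaced by one conv layer):

```python
import numpy as np
from tensorflow.keras import layers as L
from tensorflow.keras.models import Model

# toy stand-in for the full network; only the named layers matter here
ip  = L.Input(shape=(8, 8, 3), name="input_layer")
x   = L.Conv2D(4, 3, padding='same', activation='relu')(ip)
c10 = L.Conv2D(2, 1, activation='softmax', name="output_softmax_layer")(x)
extra = L.Lambda(lambda t: t * 2.0)(c10)  # stands in for the loss plumbing after c10
model = Model(ip, extra)

# extract the inference subgraph: same input, but stop at the softmax layer
inference_model = Model(model.input,
                        model.get_layer("output_softmax_layer").output)
heatmap = inference_model.predict(np.random.rand(1, 8, 8, 3).astype('float32'),
                                  verbose=0)
```

The extracted Model shares weights with the original, so anything learned during training is reflected in the inference predictions.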

    【Discussion】:

    • Thank you very much for your answer. Apart from this implementation of the pixel-wise loss: would you recommend doing it this way, or is there another, simpler approach?
    • I tried training again. While I can now predict something (woohoo), the loss stays constant and I just get blank predictions. Training on the same data without the pixel-wise loss shows the typical log-loss behavior. Could this be caused by a wrong loss function?
    • There is one thing I don't understand about the activation/loss. You have y_train_val = (imgs, 256, 256, 1) and n_classes=1, right? So c10 has 1 channel. I don't think softmax is the activation that should be used here, because it outputs 1.0 for every pixel regardless of the values. You should use sigmoid instead. Use softmax for binary classification IFF you have a length-2 vector output (a one-hot encoded vector per pixel, positive class = [1,0] and negative class = [0,1]).
    • Yes, my "classes" are just positive/negative. The model in the blog post above uses a softmax activation. I will try the sigmoid version. My main concern is that the loss function above does not compute the weighted pixel-wise loss.
    • I will try training with your simpler version. It makes the code much more understandable. Am I correct in assuming that your targets input refers to the ground-truth masks?
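The softmax-vs-sigmoid point from the comments is easy to verify numerically: with a single output channel, softmax normalizes each pixel against itself and always returns 1.0, whereas sigmoid yields a usable probability. A quick NumPy illustration:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

logits = np.array([[[-3.0], [0.5], [4.0]]])  # one channel per pixel
soft = softmax(logits)   # every pixel collapses to 1.0
sig  = sigmoid(logits)   # meaningful per-pixel probabilities
```

This is why a 1-channel softmax model can only ever predict the positive class; either switch to sigmoid with binary cross-entropy, or keep softmax with n_classes=2 and one-hot masks.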