【Question Title】: Multi-Class Image Segmentation
【Posted】: 2021-03-20 16:16:46
【Question】:

I have trained a U-Net model for multi-class image segmentation. There are four classes; here is the code:

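import tensorflow as tf
from tensorflow.keras.layers import (Input, Conv2D, BatchNormalization, Activation,
                                     MaxPooling2D, UpSampling2D, concatenate)
from tensorflow.keras.models import Model

def Unet(input_shape=(128, 128, 3), num_classes=4):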
    inputs = Input(shape=input_shape)

    down1 = Conv2D(64, (3, 3), padding='same')(inputs)
    down1 = BatchNormalization()(down1)
    down1 = Activation('relu')(down1)
    down1 = Conv2D(64, (3, 3), padding='same')(down1)
    down1 = BatchNormalization()(down1)
    down1 = Activation('relu')(down1)
    down1_pool = MaxPooling2D((2, 2), strides=(2, 2))(down1)

    down2 = Conv2D(128, (3, 3), padding='same')(down1_pool)
    down2 = BatchNormalization()(down2)
    down2 = Activation('relu')(down2)
    down2 = Conv2D(128, (3, 3), padding='same')(down2)
    down2 = BatchNormalization()(down2)
    down2 = Activation('relu')(down2)
    down2_pool = MaxPooling2D((2, 2), strides=(2, 2))(down2)

    down3 = Conv2D(256, (3, 3), padding='same')(down2_pool)
    down3 = BatchNormalization()(down3)
    down3 = Activation('relu')(down3)
    down3 = Conv2D(256, (3, 3), padding='same')(down3)
    down3 = BatchNormalization()(down3)
    down3 = Activation('relu')(down3)
    down3_pool = MaxPooling2D((2, 2), strides=(2, 2))(down3)

    down4 = Conv2D(512, (3, 3), padding='same')(down3_pool)
    down4 = BatchNormalization()(down4)
    down4 = Activation('relu')(down4)
    down4 = Conv2D(512, (3, 3), padding='same')(down4)
    down4 = BatchNormalization()(down4)
    down4 = Activation('relu')(down4)
    down4_pool = MaxPooling2D((2, 2), strides=(2, 2))(down4)

    center = Conv2D(1024, (3, 3), padding='same')(down4_pool)
    center = BatchNormalization()(center)
    center = Activation('relu')(center)
    center = Conv2D(1024, (3, 3), padding='same')(center)
    center = BatchNormalization()(center)
    center = Activation('relu')(center)

    up4 = UpSampling2D((2, 2))(center)
    up4 = concatenate([down4, up4], axis=3)
    up4 = Conv2D(512, (3, 3), padding='same')(up4)
    up4 = BatchNormalization()(up4)
    up4 = Activation('relu')(up4)
    up4 = Conv2D(512, (3, 3), padding='same')(up4)
    up4 = BatchNormalization()(up4)
    up4 = Activation('relu')(up4)
    up4 = Conv2D(512, (3, 3), padding='same')(up4)
    up4 = BatchNormalization()(up4)
    up4 = Activation('relu')(up4)

    up3 = UpSampling2D((2, 2))(up4)
    up3 = concatenate([down3, up3], axis=3)
    up3 = Conv2D(256, (3, 3), padding='same')(up3)
    up3 = BatchNormalization()(up3)
    up3 = Activation('relu')(up3)
    up3 = Conv2D(256, (3, 3), padding='same')(up3)
    up3 = BatchNormalization()(up3)
    up3 = Activation('relu')(up3)
    up3 = Conv2D(256, (3, 3), padding='same')(up3)
    up3 = BatchNormalization()(up3)
    up3 = Activation('relu')(up3)

    up2 = UpSampling2D((2, 2))(up3)
    up2 = concatenate([down2, up2], axis=3)
    up2 = Conv2D(128, (3, 3), padding='same')(up2)
    up2 = BatchNormalization()(up2)
    up2 = Activation('relu')(up2)
    up2 = Conv2D(128, (3, 3), padding='same')(up2)
    up2 = BatchNormalization()(up2)
    up2 = Activation('relu')(up2)
    up2 = Conv2D(128, (3, 3), padding='same')(up2)
    up2 = BatchNormalization()(up2)
    up2 = Activation('relu')(up2)

    up1 = UpSampling2D((2, 2))(up2)
    up1 = concatenate([down1, up1], axis=3)
    up1 = Conv2D(64, (3, 3), padding='same')(up1)
    up1 = BatchNormalization()(up1)
    up1 = Activation('relu')(up1)
    up1 = Conv2D(64, (3, 3), padding='same')(up1)
    up1 = BatchNormalization()(up1)
    up1 = Activation('relu')(up1)
    up1 = Conv2D(64, (3, 3), padding='same')(up1)
    up1 = BatchNormalization()(up1)
    up1 = Activation('relu')(up1)

    
    # 1x1 convolution: one softmax probability per class for every pixel
    classify = Conv2D(num_classes, (1, 1), activation='softmax')(up1)

    model = Model(inputs=inputs, outputs=classify)
    lr = 1e-4
    model.compile(optimizer=tf.keras.optimizers.Adam(lr), loss="sparse_categorical_crossentropy", metrics=['accuracy'])

    return model

model = Unet()
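With `loss="sparse_categorical_crossentropy"`, each target mask must hold one integer class ID (0 to `num_classes - 1`) per pixel rather than a color or a one-hot vector. If the mask images are color-coded, they need converting first. A minimal NumPy sketch, assuming a hypothetical `PALETTE` that maps each class to one RGB color (replace it with the colors actually used in your masks); `rgb_mask_to_labels` is an illustrative helper, not part of Keras:

```python
import numpy as np

# Hypothetical palette: the RGB color used for each class in the mask images
PALETTE = np.array([
    [0, 0, 0],      # class 0
    [255, 0, 0],    # class 1
    [0, 255, 0],    # class 2
    [0, 0, 255],    # class 3
], dtype=np.uint8)


def rgb_mask_to_labels(mask_rgb):
    """Convert an (H, W, 3) color mask into an (H, W, 1) int32 label map."""
    # Compare every pixel with every palette color -> (H, W, num_classes) booleans
    matches = np.all(mask_rgb[..., None, :] == PALETTE[None, None, :, :], axis=-1)
    if not matches.any(axis=-1).all():
        raise ValueError("mask contains colors that are not in PALETTE")
    # Index of the matching palette entry is the class ID
    return np.argmax(matches, axis=-1)[..., None].astype(np.int32)


# Usage: a synthetic mask whose top half is class 1 (red) and bottom half class 0
mask = np.zeros((128, 128, 3), dtype=np.uint8)
mask[:64, :, 0] = 255
labels = rgb_mask_to_labels(mask)
print(labels.shape)        # (128, 128, 1)
print(np.unique(labels))   # [0 1]
```

The raised error is deliberate: anti-aliased or JPEG-compressed masks often contain in-between colors, and silently mapping them to class 0 would corrupt the training targets.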

The model summary is:

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_4 (InputLayer)            [(None, 128, 128, 3) 0                                            
__________________________________________________________________________________________________
conv2d_69 (Conv2D)              (None, 128, 128, 64) 1792        input_4[0][0]                    
__________________________________________________________________________________________________
batch_normalization_66 (BatchNo (None, 128, 128, 64) 256         conv2d_69[0][0]                  
__________________________________________________________________________________________________
activation_66 (Activation)      (None, 128, 128, 64) 0           batch_normalization_66[0][0]     
__________________________________________________________________________________________________
conv2d_70 (Conv2D)              (None, 128, 128, 64) 36928       activation_66[0][0]              
__________________________________________________________________________________________________
batch_normalization_67 (BatchNo (None, 128, 128, 64) 256         conv2d_70[0][0]                  
__________________________________________________________________________________________________
activation_67 (Activation)      (None, 128, 128, 64) 0           batch_normalization_67[0][0]     
__________________________________________________________________________________________________
max_pooling2d_12 (MaxPooling2D) (None, 64, 64, 64)   0           activation_67[0][0]              
__________________________________________________________________________________________________
conv2d_71 (Conv2D)              (None, 64, 64, 128)  73856       max_pooling2d_12[0][0]           
__________________________________________________________________________________________________
batch_normalization_68 (BatchNo (None, 64, 64, 128)  512         conv2d_71[0][0]                  
__________________________________________________________________________________________________
activation_68 (Activation)      (None, 64, 64, 128)  0           batch_normalization_68[0][0]     
__________________________________________________________________________________________________
conv2d_72 (Conv2D)              (None, 64, 64, 128)  147584      activation_68[0][0]              
__________________________________________________________________________________________________
batch_normalization_69 (BatchNo (None, 64, 64, 128)  512         conv2d_72[0][0]                  
__________________________________________________________________________________________________
activation_69 (Activation)      (None, 64, 64, 128)  0           batch_normalization_69[0][0]     
__________________________________________________________________________________________________
max_pooling2d_13 (MaxPooling2D) (None, 32, 32, 128)  0           activation_69[0][0]              
__________________________________________________________________________________________________
conv2d_73 (Conv2D)              (None, 32, 32, 256)  295168      max_pooling2d_13[0][0]           
__________________________________________________________________________________________________
batch_normalization_70 (BatchNo (None, 32, 32, 256)  1024        conv2d_73[0][0]                  
__________________________________________________________________________________________________
activation_70 (Activation)      (None, 32, 32, 256)  0           batch_normalization_70[0][0]     
__________________________________________________________________________________________________
conv2d_74 (Conv2D)              (None, 32, 32, 256)  590080      activation_70[0][0]              
__________________________________________________________________________________________________
batch_normalization_71 (BatchNo (None, 32, 32, 256)  1024        conv2d_74[0][0]                  
__________________________________________________________________________________________________
activation_71 (Activation)      (None, 32, 32, 256)  0           batch_normalization_71[0][0]     
__________________________________________________________________________________________________
max_pooling2d_14 (MaxPooling2D) (None, 16, 16, 256)  0           activation_71[0][0]              
__________________________________________________________________________________________________
conv2d_75 (Conv2D)              (None, 16, 16, 512)  1180160     max_pooling2d_14[0][0]           
__________________________________________________________________________________________________
batch_normalization_72 (BatchNo (None, 16, 16, 512)  2048        conv2d_75[0][0]                  
__________________________________________________________________________________________________
activation_72 (Activation)      (None, 16, 16, 512)  0           batch_normalization_72[0][0]     
__________________________________________________________________________________________________
conv2d_76 (Conv2D)              (None, 16, 16, 512)  2359808     activation_72[0][0]              
__________________________________________________________________________________________________
batch_normalization_73 (BatchNo (None, 16, 16, 512)  2048        conv2d_76[0][0]                  
__________________________________________________________________________________________________
activation_73 (Activation)      (None, 16, 16, 512)  0           batch_normalization_73[0][0]     
__________________________________________________________________________________________________
max_pooling2d_15 (MaxPooling2D) (None, 8, 8, 512)    0           activation_73[0][0]              
__________________________________________________________________________________________________
conv2d_77 (Conv2D)              (None, 8, 8, 1024)   4719616     max_pooling2d_15[0][0]           
__________________________________________________________________________________________________
batch_normalization_74 (BatchNo (None, 8, 8, 1024)   4096        conv2d_77[0][0]                  
__________________________________________________________________________________________________
activation_74 (Activation)      (None, 8, 8, 1024)   0           batch_normalization_74[0][0]     
__________________________________________________________________________________________________
conv2d_78 (Conv2D)              (None, 8, 8, 1024)   9438208     activation_74[0][0]              
__________________________________________________________________________________________________
batch_normalization_75 (BatchNo (None, 8, 8, 1024)   4096        conv2d_78[0][0]                  
__________________________________________________________________________________________________
activation_75 (Activation)      (None, 8, 8, 1024)   0           batch_normalization_75[0][0]     
__________________________________________________________________________________________________
up_sampling2d_12 (UpSampling2D) (None, 16, 16, 1024) 0           activation_75[0][0]              
__________________________________________________________________________________________________
concatenate_12 (Concatenate)    (None, 16, 16, 1536) 0           activation_73[0][0]              
                                                                 up_sampling2d_12[0][0]           
__________________________________________________________________________________________________
conv2d_79 (Conv2D)              (None, 16, 16, 512)  7078400     concatenate_12[0][0]             
__________________________________________________________________________________________________
batch_normalization_76 (BatchNo (None, 16, 16, 512)  2048        conv2d_79[0][0]                  
__________________________________________________________________________________________________
activation_76 (Activation)      (None, 16, 16, 512)  0           batch_normalization_76[0][0]     
__________________________________________________________________________________________________
conv2d_80 (Conv2D)              (None, 16, 16, 512)  2359808     activation_76[0][0]              
__________________________________________________________________________________________________
batch_normalization_77 (BatchNo (None, 16, 16, 512)  2048        conv2d_80[0][0]                  
__________________________________________________________________________________________________
activation_77 (Activation)      (None, 16, 16, 512)  0           batch_normalization_77[0][0]     
__________________________________________________________________________________________________
conv2d_81 (Conv2D)              (None, 16, 16, 512)  2359808     activation_77[0][0]              
__________________________________________________________________________________________________
batch_normalization_78 (BatchNo (None, 16, 16, 512)  2048        conv2d_81[0][0]                  
__________________________________________________________________________________________________
activation_78 (Activation)      (None, 16, 16, 512)  0           batch_normalization_78[0][0]     
__________________________________________________________________________________________________
up_sampling2d_13 (UpSampling2D) (None, 32, 32, 512)  0           activation_78[0][0]              
__________________________________________________________________________________________________
concatenate_13 (Concatenate)    (None, 32, 32, 768)  0           activation_71[0][0]              
                                                                 up_sampling2d_13[0][0]           
__________________________________________________________________________________________________
conv2d_82 (Conv2D)              (None, 32, 32, 256)  1769728     concatenate_13[0][0]             
__________________________________________________________________________________________________
batch_normalization_79 (BatchNo (None, 32, 32, 256)  1024        conv2d_82[0][0]                  
__________________________________________________________________________________________________
activation_79 (Activation)      (None, 32, 32, 256)  0           batch_normalization_79[0][0]     
__________________________________________________________________________________________________
conv2d_83 (Conv2D)              (None, 32, 32, 256)  590080      activation_79[0][0]              
__________________________________________________________________________________________________
batch_normalization_80 (BatchNo (None, 32, 32, 256)  1024        conv2d_83[0][0]                  
__________________________________________________________________________________________________
activation_80 (Activation)      (None, 32, 32, 256)  0           batch_normalization_80[0][0]     
__________________________________________________________________________________________________
conv2d_84 (Conv2D)              (None, 32, 32, 256)  590080      activation_80[0][0]              
__________________________________________________________________________________________________
batch_normalization_81 (BatchNo (None, 32, 32, 256)  1024        conv2d_84[0][0]                  
__________________________________________________________________________________________________
activation_81 (Activation)      (None, 32, 32, 256)  0           batch_normalization_81[0][0]     
__________________________________________________________________________________________________
up_sampling2d_14 (UpSampling2D) (None, 64, 64, 256)  0           activation_81[0][0]              
__________________________________________________________________________________________________
concatenate_14 (Concatenate)    (None, 64, 64, 384)  0           activation_69[0][0]              
                                                                 up_sampling2d_14[0][0]           
__________________________________________________________________________________________________
conv2d_85 (Conv2D)              (None, 64, 64, 128)  442496      concatenate_14[0][0]             
__________________________________________________________________________________________________
batch_normalization_82 (BatchNo (None, 64, 64, 128)  512         conv2d_85[0][0]                  
__________________________________________________________________________________________________
activation_82 (Activation)      (None, 64, 64, 128)  0           batch_normalization_82[0][0]     
__________________________________________________________________________________________________
conv2d_86 (Conv2D)              (None, 64, 64, 128)  147584      activation_82[0][0]              
__________________________________________________________________________________________________
batch_normalization_83 (BatchNo (None, 64, 64, 128)  512         conv2d_86[0][0]                  
__________________________________________________________________________________________________
activation_83 (Activation)      (None, 64, 64, 128)  0           batch_normalization_83[0][0]     
__________________________________________________________________________________________________
conv2d_87 (Conv2D)              (None, 64, 64, 128)  147584      activation_83[0][0]              
__________________________________________________________________________________________________
batch_normalization_84 (BatchNo (None, 64, 64, 128)  512         conv2d_87[0][0]                  
__________________________________________________________________________________________________
activation_84 (Activation)      (None, 64, 64, 128)  0           batch_normalization_84[0][0]     
__________________________________________________________________________________________________
up_sampling2d_15 (UpSampling2D) (None, 128, 128, 128 0           activation_84[0][0]              
__________________________________________________________________________________________________
concatenate_15 (Concatenate)    (None, 128, 128, 192 0           activation_67[0][0]              
                                                                 up_sampling2d_15[0][0]           
__________________________________________________________________________________________________
conv2d_88 (Conv2D)              (None, 128, 128, 64) 110656      concatenate_15[0][0]             
__________________________________________________________________________________________________
batch_normalization_85 (BatchNo (None, 128, 128, 64) 256         conv2d_88[0][0]                  
__________________________________________________________________________________________________
activation_85 (Activation)      (None, 128, 128, 64) 0           batch_normalization_85[0][0]     
__________________________________________________________________________________________________
conv2d_89 (Conv2D)              (None, 128, 128, 64) 36928       activation_85[0][0]              
__________________________________________________________________________________________________
batch_normalization_86 (BatchNo (None, 128, 128, 64) 256         conv2d_89[0][0]                  
__________________________________________________________________________________________________
activation_86 (Activation)      (None, 128, 128, 64) 0           batch_normalization_86[0][0]     
__________________________________________________________________________________________________
conv2d_90 (Conv2D)              (None, 128, 128, 64) 36928       activation_86[0][0]              
__________________________________________________________________________________________________
batch_normalization_87 (BatchNo (None, 128, 128, 64) 256         conv2d_90[0][0]                  
__________________________________________________________________________________________________
activation_87 (Activation)      (None, 128, 128, 64) 0           batch_normalization_87[0][0]     
__________________________________________________________________________________________________
conv2d_91 (Conv2D)              (None, 128, 128, 4)  260         activation_87[0][0]              
==================================================================================================
Total params: 34,540,932
Trainable params: 34,527,236
Non-trainable params: 13,696
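The parameter counts in this summary can be checked by hand: a k×k `Conv2D` with C_in input channels and C_out filters has k·k·C_in·C_out + C_out parameters (e.g. 3·3·3·64 + 64 = 1792 for `conv2d_69`), and a `BatchNormalization` layer over C channels has 4·C, of which 2·C (the moving mean and variance) are non-trainable. A short sketch that recomputes the three totals from the channel widths used in `Unet()`; `conv_params`, `bn_params` and `unet_param_counts` are illustrative helpers:

```python
def conv_params(k, c_in, c_out):
    # k*k*c_in weights per filter, plus one bias per filter
    return k * k * c_in * c_out + c_out


def bn_params(c):
    # gamma and beta (trainable) + moving mean and variance (non-trainable)
    return 4 * c


def unet_param_counts(in_channels=3, num_classes=4):
    convs = []      # parameter count of every Conv2D
    bn_widths = []  # channel count of every BatchNormalization
    skips = []      # encoder outputs reused by the decoder
    c = in_channels
    # Encoder: two 3x3 conv + BN blocks per level, then max-pooling
    for width in (64, 128, 256, 512):
        for _ in range(2):
            convs.append(conv_params(3, c, width))
            bn_widths.append(width)
            c = width
        skips.append(c)
    # Center: two 3x3 conv + BN blocks with 1024 filters
    for _ in range(2):
        convs.append(conv_params(3, c, 1024))
        bn_widths.append(1024)
        c = 1024
    # Decoder: upsample, concatenate the matching skip, three 3x3 conv + BN blocks
    for width, skip in zip((512, 256, 128, 64), reversed(skips)):
        c += skip
        for _ in range(3):
            convs.append(conv_params(3, c, width))
            bn_widths.append(width)
            c = width
    # Final 1x1 softmax classifier (no BatchNormalization after it)
    convs.append(conv_params(1, c, num_classes))

    total = sum(convs) + sum(bn_params(w) for w in bn_widths)
    non_trainable = sum(2 * w for w in bn_widths)
    return total, total - non_trainable, non_trainable


print(unet_param_counts())  # (34540932, 34527236, 13696)
```

`conv2d_79` (7,078,400 parameters) is the second-largest layer because it sees the concatenated 1024 + 512 = 1536 channels.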

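I trained the model on my images and their mask images. My question is how to identify the class numbers when the model predicts on a test image. For example, after the model predicts the test image "imageabc1.png", I need to find the masks of the 4 different classes in that image. How can I do this?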

【Question Discussion】:

    Tags: deep-learning computer-vision image-segmentation unity3d-unet


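    【Solution 1】:

    Since you are using softmax, each class number corresponds to one channel of the network's output. Suppose you predict on a training image:

    pred = model.predict(training_image)  # training_image must include the batch axis: shape (1, 128, 128, 3)

    The first class is found in pred[:,:,:,0], the second class in pred[:,:,:,1], and so on.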
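Because each channel is a per-pixel probability map, taking the argmax over the last axis gives the predicted class of every pixel, and comparing that label map with each class index yields the four separate masks the question asks for. A minimal NumPy sketch; `split_prediction` is a hypothetical helper, and a random softmax array stands in for the real `model.predict` output:

```python
import numpy as np


def split_prediction(pred):
    """Turn a softmax output of shape (batch, H, W, num_classes) into
    an integer label map and one boolean mask per class."""
    # Most probable class for every pixel -> shape (batch, H, W)
    label_map = np.argmax(pred, axis=-1)
    num_classes = pred.shape[-1]
    # masks[c] is True exactly where the pixel was assigned class c
    masks = [label_map == c for c in range(num_classes)]
    return label_map, masks


# Stand-in for model.predict(...) on one 128x128 image: random logits -> softmax
rng = np.random.default_rng(0)
logits = rng.normal(size=(1, 128, 128, 4))
pred = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)

label_map, masks = split_prediction(pred)
print(label_map.shape)                    # (1, 128, 128)
print(sum(int(m.sum()) for m in masks))   # 16384: each pixel belongs to exactly one class
```

For the single image in the batch, `masks[2][0]` is then the 128×128 boolean mask of class 2; multiplying it by 255 and casting to `uint8` gives a black-and-white image that can be saved or displayed.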

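    【Discussion】:

    • What would change if I used 'sigmoid' instead of 'softmax'?
    • A softmax activation gives a probability distribution over the classes a given pixel may belong to. Softmax is usually used for multi-class problems and sigmoid for single-class (binary) ones. Sigmoid only guarantees that each output lies between 0 and 1; it lacks the softmax property that the outputs over all classes sum to 1. As a result, one pixel could be assigned to both class 1 and class 2, and then which one is it? (Unless your classes are not mutually exclusive, i.e. a pixel can belong to two classes at once.) Softmax assigns a probability to each class, and you take the largest.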
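The difference described in the last comment can be seen on a single pixel. A small NumPy sketch with made-up logits for 4 classes: softmax yields a distribution that sums to 1 and is read with argmax, while sigmoid scores each class independently, so thresholding at 0.5 (an illustrative choice) can switch on two classes at once:

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


# Hypothetical raw scores (logits) for one pixel over 4 classes
logits = np.array([2.0, 1.5, -1.0, 0.0])

p_soft = softmax(logits)
p_sig = sigmoid(logits)

print(p_soft.sum())                # approximately 1.0: a distribution over the classes
print(int(np.argmax(p_soft)))      # 0: single-label prediction = most probable class
print((p_sig > 0.5).astype(int))   # [1 1 0 0]: classes 0 and 1 are both "on"
```

The sigmoid result is what one wants for overlapping (multi-label) classes; for mutually exclusive classes, as in this question, softmax plus argmax gives exactly one class per pixel.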