【问题标题】:Learning problem during training pretrained CNN for multilabel image classification多标签图像分类训练预训练 CNN 期间的学习问题
【发布时间】:2021-07-29 11:46:23
【问题描述】:

我正在尝试训练 CNN 将图像分类为 3 个类别。每个图像可以属于多个类。所以在网络输出中,我期望每个类都有一个概率。

当我进行数据加载时,我有一个带有列的 pandas 数据框:[imageID, class 1, class 2, class 3]。图片尺寸为 (256,256,3),标签为 (3,1)(例如:如果图片属于 1 类,2 类标签为 [1,1,0])

然后,这是我的模型:

print("Define model")
base_model = tf.keras.applications.VGG16(include_top=False, input_shape=(256,256,3),weights='imagenet')
base_model.trainable = True
fine_tune_at = 15

for layer in base_model.layers[:fine_tune_at]:
    layer.trainable =  False


global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
prediction_layer = tf.keras.layers.Dense(3, activation='sigmoid')
inputs = tf.keras.Input(shape=(256,256,3))
x = base_model(inputs, training=False)

x = global_average_layer(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)

model = tf.keras.Model(inputs, outputs)


base_learning_rate = 0.00001
model.compile(optimizer=tf.keras.optimizers.Adam(lr=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=False), #True
              metrics=['accuracy'])

print("Training")
history = model.fit(train_generator, epochs = 75, validation_data= val_generator)

但我的模型没有学到任何东西,出现了问题,我不知道如何解决。 这是训练曲线:

【问题讨论】:

  • 不应该 training=False 改为 True 吗?

标签: python keras deep-learning tensorflow2.0 multilabel-classification


【解决方案1】:

https://www.tensorflow.org/api_docs/python/tf/keras/losses/BinaryCrossentropy
只有两个标签类别(假设为 0 和 1)时使用此交叉熵损失。对于每个示例,每个预测都应该有一个浮点值。

y_true = [[0., 1.], [0., 0.]]
y_pred = [[0.6, 0.4], [0.4, 0.6]]

https://www.tensorflow.org/api_docs/python/tf/keras/losses/CategoricalCrossentropy
当有两个或更多标签类别时使用此交叉熵损失函数。我们希望以 one_hot 表示形式提供标签。如果您想以整数形式提供标签,请使用 SparseCategoricalCrossentropy 损失。每个特征应该有 # 个类浮点值。

y_true = [[0, 1, 0], [0, 0, 1]]
y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]

【讨论】:

  • OP 说 multilabel 分类,所以使用 BinaryCrossentropy 的网络设置对我来说似乎是正确的。
  • 您好,感谢您的回答。但是我的模型仍然没有学习..
  • @Leili_Kue 你的数据集平衡了吗?如果没有,也许你可以考虑某种 FocalLoss
  • 另外...也许这会有所帮助tensorflow.org/addons/api_docs/python/tfa/losses/…
猜你喜欢
  • 1970-01-01
  • 2019-05-20
  • 1970-01-01
  • 1970-01-01
  • 2018-08-03
  • 2017-07-18
  • 2021-04-04
  • 2021-04-26
  • 1970-01-01
相关资源
最近更新 更多