【问题标题】:ModelCheckpoint monitoring values when the model has multiple outputs模型有多个输出时的 ModelCheckpoint 监控值
【发布时间】:2020-03-12 02:57:30
【问题描述】:

我的模型有两个输出,我想监控一个来保存我的模型。 以下是我的代码的一部分。 TensorFlow的版本是2.0

model = MobileNetBaseModel()()
model.compile(optimizer=tf.keras.optimizers.Adam(),
              metrics={"pitch_yaw_roll": "mae"},
              loss={"pitch_yaw_roll": compute_mse_loss, # or "mse"
                    "total_logits": compute_cross_entropy_loss(num_classes=num_classes)},
              loss_weights= {"pitch_yaw_roll":mse_weight, "total_logits":cross_entropy_weight})
file_path = os.path.join(checkpoint_path, "model.{epoch:2d}-{val_loss:.2f}.h5")
tf.keras.callbacks.ModelCheckpoint(filepath=file_path,
                                   monitor="val_loss",
                                   verbose=1,
                                   save_freq=save_freq,
                                   save_best_only=True)

ModelCheckpoint回调中默认的monitor='val_loss',如何选择我需要的?我要监听{"pitch_yaw_roll": "mae"}

【问题讨论】:

  • 你想达到什么目的?您只想保存"pitch_yaw_roll" 值最低的纪元吗?
  • 是的,也许我想每隔几批保存一次最低值对应的模型。正如我所描述的,在tf.keras.callbacks.ModelCheckpoint 中我只能选择monitor = val_loss 吗?谢谢你的帮助! @bluesummers

标签: python tensorflow keras tensorflow2.0


【解决方案1】:

如果您希望 ModelCheckpoint 根据另一个指标值保存,请使用 .compile(metrics={...}, ...) 指标字典中该指标的键。

例如,如果您只想保存最好的"pitch_yaw_roll" epoch 结果(最好是最小值),您应该使用

tf.keras.callbacks.ModelCheckpoint(filepath=file_path,
                                   monitor="val_pitch_yaw_roll",
                                   verbose=1,
                                   mode="min",
                                   save_freq=save_freq,
                                   save_best_only=True)

如果您选择"pitch_yaw_roll" 而不是"val_pitch_yaw_roll",它将根据训练损失而不是根据验证损失进行保存

【讨论】:

  • 按照你的意思,我知道如何根据loss来保存,那如果我想用validation set的mae代替validation set的loss呢?就像monitor="val_picth_yaw_roll_mae?
  • 就像我写的一样,不要包含mae,只需使用val_pitch_yaw_roll - 因为这个关键点指向mae,所以要监控的是mae
【解决方案2】:

只是在上面添加评论,我相信您的检查点不起作用,因为要监控的值名称不正确。 一般而言,这里的解决方案可能是让您的配合创造的历史达到顶峰。

history = model.fit(...)
pd.DataFrame(history.history)

您将在此处找到应在监控语句中使用的指标名称。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2022-10-14
    • 1970-01-01
    • 1970-01-01
    • 2023-03-16
    • 2022-01-08
    • 2019-08-19
    • 1970-01-01
    • 2021-06-14
    相关资源
    最近更新 更多