【发布时间】:2020-03-12 02:57:30
【问题描述】:
我的模型有两个输出,我想监控一个来保存我的模型。 以下是我的代码的一部分。 TensorFlow的版本是2.0
model = MobileNetBaseModel()()
model.compile(optimizer=tf.keras.optimizers.Adam(),
metrics={"pitch_yaw_roll": "mae"},
loss={"pitch_yaw_roll": compute_mse_loss, # or "mse"
"total_logits": compute_cross_entropy_loss(num_classes=num_classes)},
loss_weights= {"pitch_yaw_roll":mse_weight, "total_logits":cross_entropy_weight})
file_path = os.path.join(checkpoint_path, "model.{epoch:2d}-{val_loss:.2f}.h5")
tf.keras.callbacks.ModelCheckpoint(filepath=file_path,
monitor="val_loss",
verbose=1,
save_freq=save_freq,
save_best_only=True)
ModelCheckpoint回调中默认的monitor='val_loss',如何选择我需要的?我要监听{"pitch_yaw_roll": "mae"}。
【问题讨论】:
-
你想达到什么目的?您只想保存
"pitch_yaw_roll"值最低的纪元吗? -
是的,也许我想每隔几批保存一次最低值对应的模型。正如我所描述的,在
tf.keras.callbacks.ModelCheckpoint中我只能选择monitor = val_loss吗?谢谢你的帮助! @bluesummers
标签: python tensorflow keras tensorflow2.0