【问题标题】:How can i use tf.keras.callbacks.ModelCheckpoint in Keras Tuner?如何在 Keras Tuner 中使用 tf.keras.callbacks.ModelCheckpoint?
【发布时间】:2021-11-24 02:20:18
【问题描述】:

所以我想在 Keras Tuner 中使用 tf.keras.callbacks.ModelCheckpoint,但是您选择保存检查点的路径的方式不允许您将其保存为具有特定名称、名称的文件与该检查点的试验和执行相关联,仅与一个 epoch 相关联。

也就是说,如果我简单地将这个回调放在 Keras Tuner 中,在检查点保存发生的那一刻,最后,我将不知道如何将保存的检查点与试验和试验执行相关联,只与 epoch .

【问题讨论】:

    标签: python tensorflow machine-learning keras keras-tuner


    【解决方案1】:

    您可以将tf.keras.callbacks.ModelCheckpoint 用于Keras tuner,就像在其他模型中使用的一样来保存检查点。

    使用this模型搜索得到的超参数训练模型后,可以定义模型检查点并保存如下:

    hypermodel = tuner.hypermodel.build(best_hps)
    
    # Retrain the model
    hypermodel.fit(img_train, label_train, epochs=best_epoch, validation_split=0.2)
    
    import os
    checkpoint_path = "training_1/cp.ckpt"
    checkpoint_dir = os.path.dirname(checkpoint_path)
    
    # Create a callback that saves the model's weights
    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                     save_weights_only=True,
                                                     verbose=1)
    history = hypermodel.fit(img_train, label_train, epochs=5, validation_split=0.2, callbacks=[cp_callback])
    os.listdir(checkpoint_dir)
    
    # Re-evaluate the model
    loss, acc = hypermodel.evaluate(img_test, label_test, verbose=2)
    print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
    
    # Loads the weights
    hypermodel.load_weights(checkpoint_path)
    
    # Re-evaluate the model
    loss, acc = hypermodel.evaluate(img_test, label_test, verbose=2)
    print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
    

    有关保存和加载模型检查点的更多信息,请参阅this 链接。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2020-04-05
      • 2021-12-28
      • 1970-01-01
      • 1970-01-01
      • 2021-05-23
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多