【问题标题】:Tensorflow: SavedModelBuilder, How to save model with best validation accuracyTensorflow:SavedModelBuilder,如何以最佳验证精度保存模型
【发布时间】:2018-01-30 04:57:20
【问题描述】:

我浏览了 tensorflow 文档,但找不到使用 SavedModelBuilder 类以最佳验证精度保存模型的方法。 我正在使用 tflearn 进行模型构建,下面是我尝试过的工作,但这需要很多时间,我分别为每个时期运行拟合方法并保存模型

for i in range(epoch):
    model.fit(trainX, trainY, n_epoch=1, validation_set=(testX, testY), show_metric=True, batch_size=8)
    builder = tf.saved_model.builder.SavedModelBuilder('/tmp/serving/model/' + str(i))
    builder.add_meta_graph_and_variables(model.session,
                                     ['TRAINING'],
                                     signature_def_map={
                                         'predict': prediction_sig
                                     })
    builder.save()

如果有更好的方法,请提出建议。

【问题讨论】:

    标签: python tensorflow tensorflow-serving tflearn


    【解决方案1】:

    想通了。它可以通过 tflearn 回调来实现。 谢谢。

    class SaveModelCallback(tflearn.callbacks.Callback):
    def __init__(self, accuracy_threshold):
        self.accuracy_threshold = accuracy_threshold
        self.accuracy = []
        self.max_accuracy = -1
    
    def on_epoch_end(self, training_state):
        self.accuracy.append(training_state.global_acc)
        if training_state.val_acc > self.accuracy_threshold and training_state.val_acc > self.max_accuracy:
            self.max_accuracy = training_state.val_acc
            epoch = training_state.epoch
            self.save_model(epoch)
    
    def save_model(self, epoch):
        print('saved epoch ' + str(epoch))
        builder = tf.saved_model.builder.SavedModelBuilder('/tmp/serving/model/' + str(epoch))
        builder.add_meta_graph_and_variables(model.session,
                                             [tf.saved_model.tag_constants.SERVING],
                                             signature_def_map={
                                                 'predict': prediction_sig,
                                                 tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                                                     classification_signature,
                                             })
        builder.save()
    
    callback = SaveModelCallback(accuracy_threshold=0.8)
    model.fit(trainX, trainY, n_epoch=200, validation_set=(testX, testY), show_metric=True, batch_size=8,
              callbacks=callback)
    

    【讨论】:

    • 你知道tflearn是怎么做到的吗?有相关代码吗?我在没有 tflearn 的代码中需要这个。
    • @WeiLiu,我已经编辑了答案,让我知道它是否适合你
    猜你喜欢
    • 2017-01-08
    • 1970-01-01
    • 2015-06-05
    • 1970-01-01
    • 1970-01-01
    • 2019-04-26
    • 2021-11-04
    • 2018-06-27
    • 1970-01-01
    相关资源
    最近更新 更多