【发布时间】:2021-12-22 12:50:09
【问题描述】:
我已经按照this 官方教程中提供的步骤使用 TensorFlow API 训练了一个对象检测模型。因此,在整个过程结束时,如the exporting step 中所述,我已将模型保存为以下格式。
my_model/
├─ checkpoint/
├─ saved_model/
└─ pipeline.config
我的问题是,一旦模型被保存为这种格式,我该如何加载它并使用它来进行检测?
我可以使用下面的代码通过训练检查点成功地做到这一点。并且在该点之后(我加载生成最佳结果的检查点)导出模型。
# Load pipeline config and build a detection model
configs = config_util.get_configs_from_pipeline_file(PATH_TO_PIPELINE_CONFIG)
model_config = configs['model']
detection_model = model_builder.build(model_config=model_config, is_training=False)
# Restore checkpoint
ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
ckpt.restore(PATH_TO_CKPT).expect_partial()
但是,在生产中,我不打算使用这些检查点。我希望从导出的格式中加载模型。
我尝试了以下命令来加载导出的模型,但我没有运气。它没有返回错误,我可以使用下面的 model 变量进行检测,但是输出(边界框、类、分数)不正确,这让我相信加载中缺少一些步骤过程。
model = tf.saved_model.load(path_to_exported_model)
有什么建议吗?
【问题讨论】:
标签: tensorflow object-detection