【问题标题】:save and load model in tensorflow 2.0在 tensorflow 2.0 中保存和加载模型
【发布时间】:2020-06-03 14:59:29
【问题描述】:

我使用此代码从 tensorflow 2.x 中的预制估算器中保存了一个模型

import os
serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
tf.feature_column.make_parse_example_spec(my_feature_columns))
estimator_base_path = os.path.join( 'from_estimator')
estimator_path = classifier.export_saved_model(estimator_base_path, serving_input_fn)

此代码创建一个包含 .pb 文件的文件夹 我以后需要重用这个模型,我尝试加载这个函数

saved_model_obj = tf.compat.v2.saved_model.load(export_dir="/model_dir/")

但是当我尝试对使用加载的模型进行预测时,它会引发此错误

predictions = saved_model_obj.predict(
input_fn=lambda: input_fn(predict_x))


---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-23-a9902ff8210c> in <module>
----> 1 predictions = saved_model_obj.predict(
      2     input_fn=lambda: input_fn(predict_x))

AttributeError: 'AutoTrackable' object has no attribute 'predict'

我如何加载 .pb 文件并进行预测,就像我从未保存和加载它一样?

【问题讨论】:

标签: python tensorflow machine-learning


【解决方案1】:

当我保存模型以备后用时,我通常会这样做:

假设你的模型是model

model.save('my_model.h5') 

这会将模型保存为 hdf5 格式。

然后当我必须再次使用它来预测时,我可以:

new_model = tf.keras.models.load_model('my_model.h5')

你可以new_model.predict()

【讨论】:

  • 我正在使用预制的估计器,特别是 DNNClassifier。这不是 keras 模型
  • 只要您可以将其加载到模型对象中并进行训练,您就可以将其保存在 hdf5 文件中
  • 使用函数 tf.compat.v2.saved_model.load 加载时,它不是模型对象,而是 Autotrackable 对象
猜你喜欢
  • 2020-05-21
  • 2020-03-16
  • 1970-01-01
  • 1970-01-01
  • 2019-04-12
  • 1970-01-01
  • 1970-01-01
  • 2021-02-21
  • 2021-03-08
相关资源
最近更新 更多