【发布时间】:2019-02-10 11:00:21
【问题描述】:
在尝试将 Keras 模型导出为 TensorFlow Estimator 以提供模型服务时,我遇到了以下问题。由于同样的问题也出现了in an answer to this question,我将说明在一个玩具示例上发生了什么,并提供我的解决方案以用于文档目的。 Tensorflow 1.12.0 和 Keras 2.2.4 会出现此行为。实际的 Keras 以及 tf.keras 都会发生这种情况。
当尝试使用tf.keras.estimator.model_to_estimator 导出从 Keras 模型创建的 Estimator 时会出现问题。调用 estimator.export_savedmodel 时,会抛出 NotFoundError 或 ValueError。
下面的代码为玩具示例重现了这一点。
创建一个 Keras 模型并保存:
import keras
model = keras.Sequential()
model.add(keras.layers.Dense(units=1,
activation='sigmoid',
input_shape=(10, )))
model.compile(loss='binary_crossentropy', optimizer='sgd')
model.save('./model.h5')
接下来,使用tf.keras.estimator.model_to_estimator 将模型转换为估计器,添加输入接收器函数并使用Savedmodel 将其导出为estimator.export_savedmodel 格式:
# Convert keras model to TF estimator
tf_files_path = './tf'
estimator =\
tf.keras.estimator.model_to_estimator(keras_model=model,
model_dir=tf_files_path)
def serving_input_receiver_fn():
return tf.estimator.export.build_raw_serving_input_receiver_fn(
{model.input_names[0]: tf.placeholder(tf.float32, shape=[None, 10])})
# Export the estimator
export_path = './export'
estimator.export_savedmodel(
export_path,
serving_input_receiver_fn=serving_input_receiver_fn())
这会抛出:
ValueError: Couldn't find trained model at ./tf.
【问题讨论】:
标签: python tensorflow keras