Here is an example of tf.estimator.export.build_raw_serving_input_receiver_fn().
It can be pasted directly into a notebook running TF 2.x. Hope it helps.
import tensorflow as tf

checkpoint_dir = "/some/location/to/store/the_model"
input_column = tf.feature_column.numeric_column("x")

# Use a LinearClassifier, but this also works with a custom Estimator.
estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])

# Create a toy dataset with a single feature 'x' and an associated label.
def input_fn():
    return tf.data.Dataset.from_tensor_slices(
        ({"x": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)

estimator.train(input_fn)

# Note that raw_input_fn must not be called directly: doing so raises
# "tf.placeholder() is not compatible with eager execution."
# Instead, pass raw_input_fn to estimator.export_saved_model().
feature_to_tensor = {
    # Pass a dummy tensor: it only supplies the dtype and shape for the
    # placeholder that build_raw_serving_input_receiver_fn() will create.
    # Adjust to match the shape of 'x'.
    'x': tf.constant(0.),
}
raw_input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
    feature_to_tensor, default_batch_size=None)
export_dir = estimator.export_saved_model(checkpoint_dir, raw_input_fn).decode()
You can then inspect the exported model:
!saved_model_cli show --all --dir $export_dir
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['predict']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['x'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1)
        name: Placeholder:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['all_class_ids'] tensor_info:
        dtype: DT_INT32
        shape: (-1, 2)
        name: head/predictions/Tile:0
    outputs['all_classes'] tensor_info:
        dtype: DT_STRING
        shape: (-1, 2)
        name: head/predictions/Tile_1:0
    outputs['class_ids'] tensor_info:
        dtype: DT_INT64
        shape: (-1, 1)
        name: head/predictions/ExpandDims:0
    outputs['classes'] tensor_info:
        dtype: DT_STRING
        shape: (-1, 1)
        name: head/predictions/str_classes:0
    outputs['logistic'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: head/predictions/logistic:0
    outputs['logits'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: linear/linear_model/linear/linear_model/linear/linear_model/weighted_sum:0
    outputs['probabilities'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 2)
        name: head/predictions/probabilities:0
  Method name is: tensorflow/serving/predict
The exported model can now be loaded by another process and used for inference:
import tensorflow as tf
imported = tf.saved_model.load(export_dir)
f = imported.signatures["predict"]
f(x=tf.constant([-2., 5., -3.]))
{'class_ids': <tf.Tensor: shape=(3, 1), dtype=int64, numpy=
 array([[1],
        [0],
        [1]], dtype=int64)>,
 'classes': <tf.Tensor: shape=(3, 1), dtype=string, numpy=
 array([[b'1'],
        [b'0'],
        [b'1']], dtype=object)>,
 'all_class_ids': <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
 array([[0, 1],
        [0, 1],
        [0, 1]])>,
 ...etc...