【问题标题】:Freezing and Exporting a TensorFlow model in Tensorflow 2.0在 TensorFlow 2.0 中冻结和导出 TensorFlow 模型
【发布时间】:2019-08-09 21:24:19
【问题描述】:

我正在尝试将使用 Tensorflow 1.13 编写的现有代码(使用 Estimator)迁移到 Tensorflow 2.0,但在尝试找到等效的 API 以冻结和输出图形并输出 .pb 文件时遇到问题。

在 tensorflow 1.13 中,估计器类有一个函数 export_savedmodel,它接受一个模型路径和一个 serving_input_receiver_fn。我无法设置 serving_input_receiver_fn,因为它似乎需要占位符。然而,当迁移到 Tensorflow 2.0 时,尽管存在相同的 API,但由于急切执行模型设置为默认值,占位符不适用于急切执行模式。

   def export(self):
        self.configure()
        a_shape = (None, None, None, self.IMG_CHANNELS)
        b_shape = tf.TensorShape((None, None, self.IMU_DATA_DIM))
        a = tf.compat.v1.placeholder(tf.float32, a_shape, name="a")
        b = tf.compat.v1.placeholder(tf.float32, b_shape, name='b')
        input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
            'a': a,
            'b':b
        })
        return self.modelPath, input_fn

RuntimeError: tf.placeholder() 与急切执行不兼容。

因此,我想问一下,从现有检查点文件中冻结和导出模型以输出 .pb 文件的正确方法是什么?

【问题讨论】:

  • 看看这个issue,他们在函数内部调用模型并将函数导出为签名。

标签: python tensorflow


【解决方案1】:

这是tf.estimator.export.build_raw_serving_input_receiver_fn() 的示例。 它可以直接粘贴到带有 TF2.x 的笔记本中。希望对您有所帮助。

import tensorflow as tf

checkpoint_dir = "/some/location/to/store/the_model"

input_column = tf.feature_column.numeric_column("x")
# Use a LinearClassifier but this would also work with a custom Estimator
estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])

# Create a fake dataset with only one feature 'x' and an associated label
def input_fn():
    return tf.data.Dataset.from_tensor_slices(
        ({"x": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)
estimator.train(input_fn)

# The thing is that we must not call raw_input_fn: would result in the error 
# "tf.placeholder() is not compatible with eager execution."
# Instead pass raw_input_fn directly to estimator.export_saved_model()

feature_to_tensor = {
    # pass some dummy tensor: this is just to get the shapes for the placeholder
    # that will be created by build_raw_serving_input_receiver_fn(). 
    # Adjust with the shape of 'x'.
    # 
    'x': tf.constant(0.),
}
raw_input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(feature_to_tensor, default_batch_size=None)
export_dir = estimator.export_saved_model(checkpoint_dir, raw_input_fn).decode()

然后您可以检查导出的模型:

!saved_model_cli show --all --dir $export_dir

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['predict']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['x'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1)
        name: Placeholder:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['all_class_ids'] tensor_info:
        dtype: DT_INT32
        shape: (-1, 2)
        name: head/predictions/Tile:0
    outputs['all_classes'] tensor_info:
        dtype: DT_STRING
        shape: (-1, 2)
        name: head/predictions/Tile_1:0
    outputs['class_ids'] tensor_info:
        dtype: DT_INT64
        shape: (-1, 1)
        name: head/predictions/ExpandDims:0
    outputs['classes'] tensor_info:
        dtype: DT_STRING
        shape: (-1, 1)
        name: head/predictions/str_classes:0
    outputs['logistic'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: head/predictions/logistic:0
    outputs['logits'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: linear/linear_model/linear/linear_model/linear/linear_model/weighted_sum:0
    outputs['probabilities'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 2)
        name: head/predictions/probabilities:0
  Method name is: tensorflow/serving/predict

导出的模型现在可以被另一个进程加载并用于推理:

import tensorflow as tf
imported = tf.saved_model.load(export_dir)
f = imported.signatures["predict"]
f(x=tf.constant([-2., 5., -3.]))

{'class_ids': <tf.Tensor: shape=(3, 1), dtype=int64, numpy=
 array([[1],
        [0],
        [1]], dtype=int64)>,
 'classes': <tf.Tensor: shape=(3, 1), dtype=string, numpy=
 array([[b'1'],
        [b'0'],
        [b'1']], dtype=object)>,
 'all_class_ids': <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
 array([[0, 1],
        [0, 1],
        [0, 1]])>,
...etc...

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2022-01-19
    • 1970-01-01
    • 2019-01-04
    • 1970-01-01
    • 2018-05-02
    • 2021-11-22
    相关资源
    最近更新 更多