[Question title]: Converting a TF2 Object Detection API model to a frozen graph
[Posted]: 2022-01-13 22:04:50
[Question description]:

I trained the model ssd_resnet50_v1_fpn_640x640_coco17_tpu-8 with the TensorFlow Object Detection API, using https://github.com/tensorflow/models/blob/master/research/object_detection/model_main_tf2.py

I then exported it to a SavedModel: .\exporter_main_v2.py --input_type image_tensor --pipeline_config_path .\models\my_ssd_resnet50_v1_fpn\pipeline.config --trained_checkpoint_dir .\models\my_ssd_resnet50_v1_fpn\ --output_directory .\exported-models\models\Bel_model using https://github.com/tensorflow/models/blob/master/research/object_detection/exporter_main_v2.py

At this stage, inference with TensorFlow works fine, both from the SavedModel and from the checkpoint. I used this code to test inference: https://tensorflow-object-detection-api-tutorial.readthedocs.io/en/latest/_downloads/07fcc19ba03226cd3d83d4e40ec44385/auto_examples_python.zip

Next, I tried to convert the SavedModel to a frozen graph for use in OpenCV, following this approach: https://github.com/opencv/opencv/issues/16879#issuecomment-603815872

import tensorflow as tf
print(tf.__version__)

from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

loaded = tf.saved_model.load('models/mnist_test')
infer = loaded.signatures['serving_default']

f = tf.function(infer).get_concrete_function(flatten_input=tf.TensorSpec(shape=[None, 28, 28, 1], dtype=tf.float32))
f2 = convert_variables_to_constants_v2(f)
graph_def = f2.graph.as_graph_def()

# Export frozen graph
with tf.io.gfile.GFile('frozen_graph.pb', 'wb') as out_file:
    out_file.write(graph_def.SerializeToString())

Unfortunately, this step fails with the following error:

Traceback (most recent call last):
  File ".\frozen_graph.py", line 8, in <module>
    f = tf.function(infer).get_concrete_function(input_1=tf.TensorSpec(shape=[None, 640, 640, 3], dtype=tf.float32))
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\def_function.py", line 1299, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\def_function.py", line 1205, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\def_function.py", line 725, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\function.py", line 2969, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\function.py", line 3361, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\function.py", line 3196, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\framework\func_graph.py", line 990, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\def_function.py", line 634, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\framework\func_graph.py", line 977, in wrapper
    raise e.ag_error_metadata.to_exception(e)
TypeError: in user code:

    C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\function.py:1669 __call__  *
        return self._call_impl(args, kwargs)
    C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\function.py:1685 _call_impl  **
        raise structured_err
    C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\function.py:1678 _call_impl
        return self._call_with_structured_signature(args, kwargs,
    C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\function.py:1756 _call_with_structured_signature
        self._structured_signature_check_missing_args(args, kwargs)
    C:\Users\Bleach\miniconda3\envs\TFstd\lib\site-packages\tensorflow\python\eager\function.py:1775 _structured_signature_check_missing_args
        raise TypeError("{} missing required arguments: {}".format(

    TypeError: signature_wrapper(*, input_tensor) missing required arguments: input_tensor

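Please help me solve this problem. Perhaps you can suggest another way to create a frozen graph. Might there be a simpler solution if the model is trained with Keras?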

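[Comments]:

  • Has the problem been solved?
  • @HakanC, no. And the issue on GitHub is still open.
  • Hi @Bleach, thanks for your reply. I found a way to make some progress. Check this out: 1. Create a model instance, load the latest checkpoint, and save the model as .h5 (see tensorflow.org/tutorials/keras/save_and_load). 2. Convert the Keras model to a frozen model, following medium.com/@sebastingarcaacosta/… (work through it to the end, up to the part on optimizing the model). I have run into some new issues, but I think the model problem is solved.
  • Did you change shape=[None, 28, 28, 1] to the shape of your own input?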

Tags: python tensorflow


[Solution 1]:

Replace the following 3 lines:

loaded = tf.saved_model.load('models/mnist_test')
infer = loaded.signatures['serving_default']
f = tf.function(infer).get_concrete_function(flatten_input=tf.TensorSpec(shape=[None, 28, 28, 1], dtype=tf.float32))

with these:

loaded = tf.keras.models.load_model('models/mnist_test')
f = tf.function(lambda x: loaded(x))
f = f.get_concrete_function(tf.TensorSpec(loaded.inputs[0].shape, loaded.inputs[0].dtype))
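
The traceback in the question ends with `signature_wrapper(*, input_tensor) missing required arguments: input_tensor`, which suggests that the exported signature takes a keyword argument named `input_tensor`, not `input_1` or `flatten_input`. The sketch below shows one way to read the expected argument names from the signature before freezing it. It does not use the real SSD model: `TinyDetector` and `freeze_saved_model` are hypothetical names, and the uint8 `[1, None, None, 3]` input spec is an assumption about what an `image_tensor` export looks like, so check `structured_input_signature` on your own model. It also makes no claim that OpenCV can read the resulting graph.

```python
import os
import tempfile

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)


class TinyDetector(tf.Module):
    """Hypothetical stand-in for an exported detection model (not the real SSD)."""

    def __init__(self):
        super().__init__()
        self.scale = tf.Variable(2.0)  # a variable, so there is something to freeze

    # Assumed input spec: one uint8 image batch passed as `input_tensor`.
    @tf.function(input_signature=[
        tf.TensorSpec([1, None, None, 3], tf.uint8, name="input_tensor")])
    def __call__(self, input_tensor):
        x = tf.cast(input_tensor, tf.float32) * self.scale
        return {"mean_per_channel": tf.reduce_mean(x, axis=[1, 2])}


def freeze_saved_model(saved_model_dir, output_path):
    """Freeze the 'serving_default' signature and write it as a binary GraphDef."""
    loaded = tf.saved_model.load(saved_model_dir)
    infer = loaded.signatures["serving_default"]

    # Keyword arguments the signature actually expects, keyed by name.
    _, input_specs = infer.structured_input_signature

    # `infer` is already a ConcreteFunction, so it can be converted directly.
    frozen = convert_variables_to_constants_v2(infer)
    with tf.io.gfile.GFile(output_path, "wb") as out_file:
        out_file.write(frozen.graph.as_graph_def().SerializeToString())
    return input_specs, frozen


# Demo: export the toy model to a temporary directory, then freeze it.
work_dir = tempfile.mkdtemp()
export_dir = os.path.join(work_dir, "saved_model")
model = TinyDetector()
tf.saved_model.save(
    model, export_dir, signatures={"serving_default": model.__call__})

frozen_path = os.path.join(work_dir, "frozen_graph.pb")
input_specs, frozen = freeze_saved_model(export_dir, frozen_path)
print("signature inputs:", input_specs)
print("frozen graph inputs:", [t.name for t in frozen.inputs])
```

If you prefer to keep the `tf.function(infer).get_concrete_function(...)` form from the question, the keyword name and the `tf.TensorSpec` passed to it would need to match whatever `structured_input_signature` reports for your model.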

[Discussion]:
