【问题标题】:Convert .pb to .tflite for a model of variable input shape将 .pb 转换为 .tflite 以获得可变输入形状的模型
【发布时间】:2021-02-12 16:06:05
【问题描述】:

我正在解决一个问题,我使用自定义数据集使用 Tensorflow 对象检测 API 训练模型。我正在使用 tf 版本 2.2.0

output_directory = 'inference_graph'
!python /content/models/research/object_detection/exporter_main_v2.py \
--trained_checkpoint_dir {model_dir} \
--output_directory {output_directory} \
--pipeline_config_path {pipeline_config_path}

我能够成功获取 .pb 文件以及 .ckpt 文件。但现在我需要将其转换为 .tflite。我无法这样做,有一些错误或其他错误。

我尝试了写在 TensorFlow 文档上的基本方法,但也没有用。 我尝试的另一个代码如下:

    import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Conv2D, Flatten, MaxPooling2D, Dense, Input, Reshape, Concatenate, GlobalAveragePooling2D, BatchNormalization, Dropout, Activation, GlobalMaxPooling2D
from tensorflow.keras.utils import Sequence

model = tf.saved_model.load(f'/content/drive/MyDrive/FINAL DNET MODEL/inference_graph/saved_model/')
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.post_training_quantize=True
converter.inference_type=tf.uint8
tflite_model = converter.convert()
open("val_converted_model_int8.tflite", "wb").write(tflite_model)

我得到的错误是:

AttributeError Traceback(最近调用 最后)在() 8 转换器.post_training_quantize=真 9 转换器.inference_type=tf.uint8 ---> 10 tflite_model = converter.convert() 11 open("val_converted_model_int8.tflite", "wb").write(tflite_model)

/usr/local/lib/python3.6/dist-packages/tensorflow/lite/python/lite.py 在转换(自我) 837 # 无。 838 # 一旦我们对动态形状有了更好的支持,我们就可以删除它。 --> 839 如果不是实例(self._keras_model.call,_def_function.Function): 840 # Pass keep_original_batch_size=True 将确保我们得到一个输入 841 # 签名包含用户指定的批次维度。

AttributeError: '_UserObject' 对象没有属性 'call'

谁能帮我解决这个问题?

【问题讨论】:

    标签: python tensorflow2.0 tensorflow-lite custom-dataset


    【解决方案1】:

    我认为问题不在于可变输入形状(而错误消息令人困惑)。

    tf.saved_model.load 返回 SavedModel,但 tf.lite.TFLiteConverter.from_keras_model 需要 Keras 模型,因此无法处理。

    您需要使用TFLiteConverter.from_saved_model API。像这样的:

    saved_model_dir = '/content/drive/MyDrive/FINAL DNET MODEL/inference_graph/saved_model/'
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    

    如果您遇到其他问题,请告诉我们。

    【讨论】:

    • 代码运行没有任何错误,但没有保存 tflite 文件。而是给出了这个警告:WARNING:absl:Importing a function (__inference_EfficientDet-D0_layer_call_and_return_conditional_losses_90785) with ops with custom gradients。如果请求渐变,可能会失败。我需要写这个文件还是什么?因为没有输出。我刚刚使用 tf 模型。
    • 如果你运行了model = converter.convert(),但它返回一个没有错误的空字符串,这很可能是一个错误。你能用可重现的步骤提交 Github 问题吗?谢谢!
    • 我说的是你提供的代码:converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)。
    猜你喜欢
    • 2018-09-03
    • 1970-01-01
    • 1970-01-01
    • 2019-05-08
    • 1970-01-01
    • 2021-08-15
    • 1970-01-01
    • 2018-12-09
    • 1970-01-01
    相关资源
    最近更新 更多