【发布时间】:2021-02-25 19:06:23
【问题描述】:
我需要将在自定义数据集上微调的自定义对象检测模型导出到 TensorFlow Lite,以便它可以在 Android 设备上运行。
我在 Ubuntu 18.04 上使用 TensorFlow 2.4.1,到目前为止,这就是我所做的:
- 使用新图像数据集微调了“ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8”模型。我使用了存储库中的“model_main_tf2.py”脚本;
- 我使用“exporter_main_v2.py”导出模型
python exporter_main_v2.py --input_type image_tensor --pipeline_config_path .\models\custom_model\pipeline.config --trained_checkpoint_dir .\models\custom_model\ --output_directory .\exported-models\custom_model
生成了保存的模型(.pb 文件);
3.我测试了导出的模型进行推理,一切正常。在检测例程中,我使用了:
def get_model_detection_function(model):
##Get a tf.function for detection
@tf.function
def detect_fn(image):
"""Detect objects in image."""
image, shapes = model.preprocess(image)
prediction_dict = model.predict(image, shapes)
detections = model.postprocess(prediction_dict, shapes)
return detections, prediction_dict, tf.reshape(shapes, [-1])
return detect_fn
生成的图像对象的形状是 640x640,正如预期的那样。
然后,我尝试将此 .pb 模型转换为 tflite。 在更新到 tensorflow 的 nightly 版本后(使用普通版本,我得到了一个错误),我实际上能够使用以下代码生成一个 .tflite 文件:
import tensorflow as tf
from tflite_support import metadata as _metadata
saved_model_dir = 'exported-models/custom_model/'
## Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.experimental_new_converter = True
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
# Save the model.
with open('tflite/custom_model.tflite', 'wb') as f:
f.write(tflite_model)
按照here 给出的说明,我尝试在 AndroidStudio 中使用此模型。
但是,我遇到了几个错误:
- 有关“不是有效的 Tensorflow lite 模型”的内容(必须对此进行更好的检查);
- 错误:
java.lang.IllegalArgumentException: Cannot copy to a TensorFlowLite tensor (serving_default_input_tensor:0) with 3 bytes from a Java Buffer with 270000 bytes.
第二个错误似乎表明 tflite 模型的预期输入有些奇怪。 我用 Netron 检查了该文件,这就是我得到的:
输入应具有...1x1x1x3 形状,还是我误解了图表? 使用 tflite 导出器时,我应该以某种方式设置张量输入大小吗?
无论如何,导出我的自定义模型以便它可以在 Android 上运行的正确方法是什么?
【问题讨论】:
标签: tensorflow tensorflow2.0 tensorflow-lite object-detection-api