【问题标题】:How to read output from tensorflow model in java如何在java中读取tensorflow模型的输出
【发布时间】:2019-07-09 23:58:57
【问题描述】:

我尝试使用 TensorflowLite 和 ssdlite_mobilenet_v2_coco 模型从 https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md 转换为 tflite 文件,以检测我的 android 应用程序 (java) 中相机流中的对象。我执行

    interpreter.run(input, output);

其中输入是转换为 ByteBuffer 的图像,输出是浮点数组 - 大小 [1][10][4] 以匹配张量。

如何将此浮点数组转换为一些可读的输出? - 例如获取边界框坐标、对象名称、概率。

【问题讨论】:

  • 从您提供的链接中,您使用的是哪种型号?
  • ssdlite_mobilenet_v2_coco

标签: tensorflow tensorflow-lite


【解决方案1】:

好的,我想通了。 首先,我在 python 中运行以下命令:

>>> import tensorflow as tf
>>> interpreter = tf.contrib.lite.Interpreter("detect.tflite")

然后加载 Tflite 模型:

>>> interpreter.allocate_tensors()
>>> input_details = interpreter.get_input_details()
>>> output_details = interpreter.get_output_details()

现在我已经详细了解了输入和输出应该是什么样子的

>>> input_details
[{'name': 'normalized_input_image_tensor', 'index': 308, 'shape': array([  1, 300, 300,   3], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]

所以输入是转换后的图像 - 形状 300 x 300

>>> output_details
[{'name': 'TFLite_Detection_PostProcess', 'index': 300, 'shape': array([ 1, 10,  4], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:1', 'index': 301, 'shape': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:2', 'index': 302, 'shape': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:3', 'index': 303, 'shape': array([1], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]

现在我已经有了这个模型中多个输出的规范。 我需要改变

interpreter.run(input, output) 

interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);

“输入”在哪里:

private Object[1] inputs;
inputs[0] = imgData; //imgData - image converted to bytebuffer 

而 map_of_indices_to_outputs 是:

private Map<Integer, Object> output_map = new TreeMap<>();
private float[1][10][4] boxes;
private float[1][10] scores;
private float[1][10] classes;
output_map.put(0, boxes);
output_map.put(1, classes);
output_map.put(2, scores);

现在运行后,我在框中得到了 10 个对象的坐标,类中的对象索引(在可可标签文件中)你必须加 1 才能获得正确的键!和分数的概率。

希望这对将来的某人有所帮助。

【讨论】:

  • 这对我有帮助,我感谢您回来展示您的发现!谢谢!如果您可以提供帮助,仍然有一个问题-我的输入形状是 [1,15000,1],本质上它只是一个浮点数组-(它是 3d 的事实是 Keras 的其他问题)-我如何构造我的输入形状在 JAVA 中以适应预期的形状?
  • 它应该只是一个形状为 [1][15000][1] 的数组。所以也许你可以把你的普通数组转换成这个形状。例如,您输入 [1.444;2.447; 5,678...] 然后你创建了这个数组 [[ [1.444]; [2.447]; [5.678]; ... ]]。现在你应该可以通过了。
  • 感谢您的跟进! tf.contrib.lite.Interpreter 现在是 tf2 中的 tf.lite.Interpreter
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 2017-02-23
  • 1970-01-01
  • 2022-12-24
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多