好的,我想通了。
首先,我在 python 中运行以下命令:
>>> import tensorflow as tf
>>> interpreter = tf.contrib.lite.Interpreter("detect.tflite")
然后加载 Tflite 模型:
>>> interpreter.allocate_tensors()
>>> input_details = interpreter.get_input_details()
>>> output_details = interpreter.get_output_details()
现在我已经详细了解了输入和输出应该是什么样子的
>>> input_details
[{'name': 'normalized_input_image_tensor', 'index': 308, 'shape': array([ 1, 300, 300, 3], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]
所以输入是转换后的图像 - 形状 300 x 300
>>> output_details
[{'name': 'TFLite_Detection_PostProcess', 'index': 300, 'shape': array([ 1, 10, 4], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:1', 'index': 301, 'shape': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:2', 'index': 302, 'shape': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:3', 'index': 303, 'shape': array([1], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]
现在我已经有了这个模型中多个输出的规范。
我需要改变
interpreter.run(input, output)
到
interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
“输入”在哪里:
private Object[1] inputs;
inputs[0] = imgData; //imgData - image converted to bytebuffer
而 map_of_indices_to_outputs 是:
private Map<Integer, Object> output_map = new TreeMap<>();
private float[1][10][4] boxes;
private float[1][10] scores;
private float[1][10] classes;
output_map.put(0, boxes);
output_map.put(1, classes);
output_map.put(2, scores);
现在运行后,我在框中得到了 10 个对象的坐标,类中的对象索引(在可可标签文件中)你必须加 1 才能获得正确的键!和分数的概率。
希望这对将来的某人有所帮助。