【问题标题】:Tensorflow inference on android producing garbage resultandroid产生垃圾结果的Tensorflow推断
【发布时间】:2020-09-11 23:57:31
【问题描述】:

我使用 tensorflow 训练了一个模型,然后将其转换为 tensorflow-lite 格式。

模型推理在使用 python 的笔记本电脑上运行得非常好。 然后我将模型放入 Android 应用程序并使用 tensorflowlite 解释器进行推理,结果只是一张全黑的图像。 我将python中的代码原样移植到Java中,仍然得到这个垃圾结果。

知道我可能会在哪里出错。

Python 代码:

def preprocess(img):
    return (img / 255. - 0.5) * 2

def deprocess(img):
    return (img + 1) / 2

img_size = 256

frozen_model_filename = os.path.join('model/tflite', 'model.tflite')

image_1 = cv2.resize(imread(image_1), (img_size, img_size))
X_1 = np.expand_dims(preprocess(image_1), 0)
X_1 = X_1.astype(np.float32)

image_2 = cv2.resize(imread(image_2), (img_size, img_size))
X_2 = np.expand_dims(preprocess(image_2), 0)
X_2 = X_2.astype(np.float32)


interpreter = tf.lite.Interpreter(model_path=frozen_model_filename)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()


interpreter.set_tensor(input_details[0]['index'], X_1)
interpreter.set_tensor(input_details[1]['index'], X_2)
interpreter.invoke()

Output = interpreter.get_tensor(output_details[0]['index'])
Output = deprocess(Output)

imsave('result_tflite.jpg', Output[0])

Android平台对应的Java代码:

private ByteBuffer convertBitmapToByteBuffer(Bitmap bitmap) {

    Bitmap resized = Bitmap.createScaledBitmap(bitmap, IMAGE_SIZE, IMAGE_SIZE, false);

    ByteBuffer byteBuffer;

    if(isQuant) {
        byteBuffer = ByteBuffer.allocateDirect(BATCH_SIZE * IMAGE_SIZE * IMAGE_SIZE * PIXEL_SIZE);
    } else {
        byteBuffer = ByteBuffer.allocateDirect(4 * BATCH_SIZE * IMAGE_SIZE * IMAGE_SIZE * PIXEL_SIZE);
    }

    byteBuffer.order(ByteOrder.nativeOrder());

    int[] intValues = new int[IMAGE_SIZE * IMAGE_SIZE];
    resized.getPixels(intValues, 0, resized.getWidth(), 0, 0, resized.getWidth(), resized.getHeight());

    int pixel = 0;
    byteBuffer.rewind();

    for (int i = 0; i < IMAGE_SIZE; ++i) {
        for (int j = 0; j < IMAGE_SIZE; ++j) {
            final int val = intValues[pixel++];
            if(isQuant){
                byteBuffer.put((byte) ((val >> 16) & 0xFF));
                byteBuffer.put((byte) ((val >> 8) & 0xFF));
                byteBuffer.put((byte) (val & 0xFF));
            } else {
                byteBuffer.putFloat((((val >> 16) & 0xFF) - 0.5f) * 2.0f);
                byteBuffer.putFloat((((val >> 8) & 0xFF) - 0.5f) * 2.0f);
                byteBuffer.putFloat((((val) & 0xFF ) - 0.5f) *  2.0f);
            }
        }
    }
    return byteBuffer;
}

private Bitmap getOutputImage(ByteBuffer output){
    output.rewind();

    int outputWidth = IMAGE_SIZE;
    int outputHeight = IMAGE_SIZE;
    Bitmap bitmap = Bitmap.createBitmap(outputWidth, outputHeight, Bitmap.Config.ARGB_8888);
    int [] pixels = new int[outputWidth * outputHeight];
    for (int i = 0; i < outputWidth * outputHeight; i++) {
        int a = 0xFF;

        float r = (output.getFloat() + 1) / 2.0f;
        float g = (output.getFloat() + 1) / 2.0f;
        float b = (output.getFloat() + 1) / 2.0f;

        pixels[i] = a << 24 | ((int) r << 16) | ((int) g << 8) | (int) b;
    }
    bitmap.setPixels(pixels, 0, outputWidth, 0, 0, outputWidth, outputHeight);
    return bitmap;
}

private void runInference(){

    ByteBuffer byteBufferX1 = convertBitmapToByteBuffer(bitmap_x1);
    ByteBuffer byteBufferX2 = convertBitmapToByteBuffer(bitmap_x2);

    Object[] inputs = {byteBufferX1, byteBufferX2};

    ByteBuffer byteBufferOutput;

    if(isQuant) {
        byteBufferOutput = ByteBuffer.allocateDirect(BATCH_SIZE * IMAGE_SIZE * IMAGE_SIZE * PIXEL_SIZE);
    } else {
        byteBufferOutput = ByteBuffer.allocateDirect(4 * BATCH_SIZE * IMAGE_SIZE * IMAGE_SIZE * PIXEL_SIZE);
    }

    byteBufferOutput.order(ByteOrder.nativeOrder());
    byteBufferOutput.rewind();

    Map<Integer, Object> outputs = new HashMap<>();
    outputs.put(0, byteBufferOutput);

    interpreter.runForMultipleInputsOutputs(inputs, outputs);
    ByteBuffer out = (ByteBuffer) outputs.get(0);
    Bitmap outputBitmap = getOutputImage(out);

    // outputBitmap is just a full black image
}

【问题讨论】:

    标签: android tensorflow tensorflow-lite


    【解决方案1】:

    Java 和 Python 解释器都是基于 C++ 实现的,所以结果应该是一样的。错误应该在您的 JAVA 代码中。 这里我想你忘了乘除到 255。

    【讨论】:

    • 您好 Thaink,感谢您的回复;我已经尝试过你的建议,产生乱码输出。
    • 您好 Kurnar,通常情况下,您可以通过在将数据输入模型之前记录数据来调试代码,以查看它们在两种情况下是否相同。
    猜你喜欢
    • 2020-12-20
    • 2015-05-24
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-04-20
    • 1970-01-01
    • 1970-01-01
    • 2016-12-31
    相关资源
    最近更新 更多