[Question Title]: TensorFlow Lite on Android: how can I define the input and output of the 'runForMultipleInputsOutputs' function?
[Posted]: 2020-03-20 02:00:18
[Question Description]:

I am using TensorFlow Lite on Android, but the runForMultipleInputsOutputs function does not work.

Here is what I did.

1. Make a "tfile" (the .tflite model file). Here is the model source from Colab:

from numpy import mean
from numpy import std
from numpy import dstack
from pandas import read_csv
from matplotlib import pyplot
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Flatten
from keras.layers import Dropout
from keras.layers.convolutional import Conv1D
from keras.layers.convolutional import MaxPooling1D
from keras.utils import to_categorical
from tensorflow import keras
import tensorflow as tf  # needed for tf.lite below

#make the model
n_timesteps, n_features, n_outputs = 128, 9, 6
model = Sequential()
model.add(Conv1D(filters=64, kernel_size=3, activation='relu', input_shape=(n_timesteps,n_features)))
model.add(Conv1D(filters=64, kernel_size=3, activation='relu'))
model.add(Dropout(0.5))
model.add(MaxPooling1D(pool_size=2))
model.add(Flatten())
model.add(Dense(100, activation='relu'))
model.add(Dense(n_outputs, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

#save the model
model.save("/content/gdrive/My Drive/Train_data/accel_trained_model.h5")
model2 = keras.models.load_model("/content/gdrive/My Drive/Train_data/accel_trained_model.h5")
model2.save('/content/gdrive/My Drive/Train_data/tf_accel_trained_model', save_format="tf")

#convert the model and save the tfile
converter = tf.lite.TFLiteConverter.from_saved_model('/content/gdrive/My Drive/Train_data/tf_accel_trained_model')
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
with open('/content/gdrive/My Drive/Train_data/converted_model.tflite', 'wb') as f:
    f.write(tflite_model)
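The shapes the converted model works with can be derived from the layer stack above. The following is a hedged sketch, assuming Keras defaults (Conv1D with padding='valid' and stride 1, MaxPooling1D with strides equal to pool_size); ModelShapes and its methods are hypothetical names used only for illustration. The leading 1 stands for the batch dimension, which the converted model reports as 1 in the answer's shape log.

```java
// Sketch: trace tensor shapes through the Conv1D model defined above.
// Assumes Keras defaults: Conv1D padding='valid', stride 1;
// MaxPooling1D padding='valid', strides == pool_size.
public class ModelShapes {

    // A 'valid' Conv1D with stride 1 shortens the sequence by kernelSize - 1.
    static int conv1dValid(int length, int kernelSize) {
        return length - kernelSize + 1;
    }

    // A 'valid' MaxPooling1D whose stride equals poolSize.
    static int maxPool1d(int length, int poolSize) {
        return (length - poolSize) / poolSize + 1;
    }

    // Number of values produced by Flatten() for one sample.
    static int flattenSize(int nTimesteps, int filters) {
        int len = conv1dValid(nTimesteps, 3); // Conv1D #1: 128 -> 126
        len = conv1dValid(len, 3);            // Conv1D #2: 126 -> 124
        // Dropout(0.5) leaves the shape unchanged.
        len = maxPool1d(len, 2);              // MaxPooling1D: 124 -> 62
        return len * filters;                 // Flatten: 62 * 64 = 3968
    }

    public static void main(String[] args) {
        int nTimesteps = 128, nFeatures = 9, nOutputs = 6, filters = 64;
        // The leading 1 is the batch dimension of the converted model.
        System.out.println("input  : [1, " + nTimesteps + ", " + nFeatures + "]");
        System.out.println("flatten: [1, " + flattenSize(nTimesteps, filters) + "]");
        System.out.println("output : [1, " + nOutputs + "]"); // Dense(n_outputs) softmax
    }
}
```

Running main prints the input as [1, 128, 9], the flattened features as [1, 3968], and the output as [1, 6], which is why a float[128][9] input alone cannot match the model.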

2. Add the TensorFlow Lite options to the Android 'build.gradle (Module)' file:

android {
    aaptOptions {
        noCompress "tflite"
        noCompress "lite"
    }
}

dependencies {
    implementation 'org.tensorflow:tensorflow-lite:+'
}

3. Load the model on Android:

tflite = getTfliteInterpreter(modelFile);


private Interpreter getTfliteInterpreter(String modelPath) {
    try {
        return new Interpreter(loadModelFile(MainActivity.this, modelPath));
    }
    catch (Exception e) {
        e.printStackTrace();
    }
    return null;
}


private MappedByteBuffer loadModelFile(Activity activity, String MODEL_FILE) throws IOException {
    AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_FILE);
    FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
    FileChannel fileChannel = inputStream.getChannel();
    long startOffset = fileDescriptor.getStartOffset();
    long declaredLength = fileDescriptor.getDeclaredLength();
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
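loadModelFile() memory-maps the model so the interpreter can read it without copying it into the Java heap. The following is a minimal desktop-JVM sketch of the same FileChannel.map call on an ordinary file; ModelFileLoader and the temporary file are hypothetical, and Android's AssetFileDescriptor offsets are not involved. Unlike getTfliteInterpreter(), it lets the error propagate instead of printing it and returning null; a null interpreter only fails later, as a NullPointerException at its first use.

```java
import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;

// Sketch: the read-only memory mapping that loadModelFile() performs on an
// Android asset, applied to an ordinary file so it runs on a desktop JVM.
public class ModelFileLoader {

    // Maps the whole file read-only. Errors propagate as IOException instead
    // of being printed and replaced by null.
    static MappedByteBuffer mapReadOnly(Path path) throws IOException {
        try (FileChannel channel = FileChannel.open(path, StandardOpenOption.READ)) {
            // The mapping stays valid after the channel is closed.
            return channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size());
        }
    }

    public static void main(String[] args) throws IOException {
        // Hypothetical stand-in for a .tflite file: four arbitrary bytes.
        Path tmp = Files.createTempFile("model", ".tflite");
        Files.write(tmp, new byte[] {1, 2, 3, 4});
        MappedByteBuffer buffer = mapReadOnly(tmp);
        System.out.println("mapped bytes: " + buffer.capacity());
    }
}
```

A missing file surfaces immediately as an IOException (NoSuchFileException) at load time rather than as a null reference at the first inference call.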

4. Prepare the input and output and call tflite.runForMultipleInputsOutputs:

float[][] inp=new float[128][9];
float[][] out=new float[][]{{0, 0, 0, 0, 0, 0}};

java.util.Map<Integer, Object> outputs = new java.util.HashMap<>();
outputs.put(0, out);

tflite.runForMultipleInputsOutputs(inp,outputs);

Result: an error. I don't know what the correct input and output for runForMultipleInputsOutputs are.

2020-03-19 22:00:45.219 14799-14799/com.example.tensorflowlite E/AndroidRuntime: FATAL EXCEPTION: main
    Process: com.example.tensorflowlite, PID: 14799
    java.lang.NullPointerException: Attempt to invoke virtual method 'void org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(java.lang.Object[], java.util.Map)' on a null object reference
        at com.example.tensorflowlite.MainActivity$1.onClick(MainActivity.java:93)
        at android.view.View.performClick(View.java:6597)
        at android.view.View.performClickInternal(View.java:6574)
        at android.view.View.access$3100(View.java:778)
        at android.view.View$PerformClick.run(View.java:25885)
        at android.os.Handler.handleCallback(Handler.java:873)
        at android.os.Handler.dispatchMessage(Handler.java:99)
        at android.os.Looper.loop(Looper.java:193)
        at android.app.ActivityThread.main(ActivityThread.java:6669)
        at java.lang.reflect.Method.invoke(Native Method)
        at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:493)
        at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:858)
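The trace reports a null object reference, so tflite itself was null and the arguments were never checked; they still have to match the model once the interpreter loads. Interpreter.runForMultipleInputsOutputs takes an Object[] with one entry per model input and a Map<Integer, Object> from output index to an output array. Because a float[][] is itself an Object[] in Java, passing inp directly compiles, but its 128 rows are then read as 128 separate inputs. The sketch below is a hedged illustration of the containers for a model with one [1, 128, 9] input and one [1, 6] output (the shapes from the answer's logs); RunArgs and its helpers are hypothetical names, and the interpreter call is left as a comment because it needs the Android library.

```java
import java.util.HashMap;
import java.util.Map;

// Sketch: containers for Interpreter.runForMultipleInputsOutputs(Object[], Map<Integer, Object>)
// for a model with a single [1, 128, 9] input and a single [1, 6] output.
public class RunArgs {

    // One array entry per model input; this model has exactly one input tensor.
    static Object[] buildInputs(float[][][] window) {
        return new Object[] {window};
    }

    // Output tensor index -> Java array with the same shape as that output.
    static Map<Integer, Object> buildOutputs(float[][] probabilities) {
        Map<Integer, Object> outputs = new HashMap<>();
        outputs.put(0, probabilities);
        return outputs;
    }

    public static void main(String[] args) {
        float[][][] inp = new float[1][128][9]; // a batch of one 128x9 window
        float[][] out = new float[1][6];        // one row of 6 class scores

        Object[] inputs = buildInputs(inp);
        Map<Integer, Object> outputs = buildOutputs(out);
        System.out.println("inputs: " + inputs.length + ", outputs: " + outputs.size());

        // Pitfall: float[][] is assignable to Object[], so this compiles,
        // but each of the 128 rows becomes a separate "input".
        Object[] accidental = new float[128][9];
        System.out.println("accidental inputs: " + accidental.length);

        // On Android (not runnable here): tflite.runForMultipleInputsOutputs(inputs, outputs);
    }
}
```

For a single-input, single-output model, tflite.run(inp, out) is equivalent and avoids the wrapping entirely.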

[Question Discussion]:

    Tags: android tensorflow


    [Solution 1]:

    I found the problem.

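    First, don't convert the Keras model into a TensorFlow SavedModel. Convert the Keras model directly into a TensorFlow Lite model (the tfile). Here is the code for saving and converting the model: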

    import tensorflow as tf
    from tensorflow import keras
    model2 = keras.models.load_model("/content/gdrive/My Drive/Train_data/accel_trained_model.h5")
    # from_keras_model_file is the TF 1.x converter API (tf.compat.v1.lite in TF 2.x);
    # in TF 2.x the equivalent is tf.lite.TFLiteConverter.from_keras_model(model2)
    converter = tf.lite.TFLiteConverter.from_keras_model_file("/content/gdrive/My Drive/Train_data/accel_trained_model.h5")
    tflite_model = converter.convert()
    with open('/content/gdrive/My Drive/Train_data/converted_model.tflite', 'wb') as f:
        f.write(tflite_model)
    

    Second, I changed the input on Android. You can check the shapes of the input and output tensors on Android by doing this:

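    Tensor input = tflite.getInputTensor(0);   // org.tensorflow.lite.Tensor
    Tensor outi = tflite.getOutputTensor(0);
    Log.d("Tag", Arrays.toString(input.shape()));
    Log.d("Tag", Arrays.toString(outi.shape()));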
    

    In my case, the input and output shapes are:

    #input shape Log
    2020-03-20 21:33:59.608 20035-20035/com.example.tensorflowlite D/Tag: [1, 128, 9]
    #output shape Log
    2020-03-20 21:33:59.608 20035-20035/com.example.tensorflowlite D/Tag: [1, 6]
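Once the shapes are known, the Java arrays can be allocated from them instead of hard-coding each dimension. A minimal sketch, assuming the shape values from the logs above; java.lang.reflect.Array.newInstance builds the nested array, while ShapeBuffers and allocateFloat are hypothetical names. In the app, the int[] values would come from tflite.getInputTensor(0).shape() and tflite.getOutputTensor(0).shape().

```java
import java.lang.reflect.Array;
import java.util.Arrays;

// Sketch: allocate a Java float array whose nesting matches a tensor shape.
public class ShapeBuffers {

    // Array.newInstance(float.class, 1, 128, 9) creates a float[1][128][9].
    static Object allocateFloat(int[] shape) {
        return Array.newInstance(float.class, shape);
    }

    public static void main(String[] args) {
        int[] inputShape = {1, 128, 9}; // values copied from the input shape log
        int[] outputShape = {1, 6};     // values copied from the output shape log

        float[][][] inp = (float[][][]) allocateFloat(inputShape);
        float[][] out = (float[][]) allocateFloat(outputShape);

        System.out.println(Arrays.toString(inputShape) + " -> "
                + inp.length + "x" + inp[0].length + "x" + inp[0][0].length);
        System.out.println(Arrays.toString(outputShape) + " -> "
                + out.length + "x" + out[0].length);
    }
}
```

The cast still fixes the rank at compile time; the reflection only removes the need to repeat each dimension by hand.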
    

    So I changed the input and output shapes to match, like this:

    float[][][] inp=new float[1][128][9];
    float[][] out=new float[][]{{0, 0, 0, 0, 0, 0}};
    
    tflite = getTfliteInterpreter(modelFile);
    tflite.run(inp, out);
    
    private Interpreter getTfliteInterpreter(String modelPath) {
        try {
            return new Interpreter(loadModelFile(MainActivity.this, modelPath));
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }
    
    
    private MappedByteBuffer loadModelFile(Activity activity, String MODEL_FILE) throws IOException {
        AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_FILE);
        FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
        FileChannel fileChannel = inputStream.getChannel();
        long startOffset = fileDescriptor.getStartOffset();
        long declaredLength = fileDescriptor.getDeclaredLength();
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
    }
    

    Then it worked fine.
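After run(inp, out) returns, out[0] holds the six softmax scores produced by the model's final Dense(n_outputs, activation='softmax') layer. A hedged sketch of turning them into a class index; Prediction and argmax are hypothetical names, the scores in main are made up, and the mapping from index to activity label depends on how the training labels were encoded (not shown in the post).

```java
// Sketch: pick the most likely class from a [1, n] softmax output.
public class Prediction {

    // Returns the index of the largest score in row 0 of the output array.
    static int argmax(float[][] out) {
        float[] scores = out[0];
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Made-up scores standing in for what the interpreter writes into out.
        float[][] out = {{0.05f, 0.10f, 0.60f, 0.15f, 0.05f, 0.05f}};
        System.out.println("predicted class index: " + argmax(out));
    }
}
```

With these sample scores, main prints "predicted class index: 2".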

    [Discussion]:

    • Where did you get the "input" and "outi" objects used to get the input/output shapes?