【发布时间】:2018-04-28 15:28:29
【问题描述】:
我正在尝试在 Android 应用中运行一个简单的虹膜分类器。我在 keras 中创建了一个 MLP,将其转换为 .pb 格式并将其放入 assets 文件夹中。 keras 模型:
data = sklearn.datasets.load_iris()
x=data.data
y=data.target
x=np.array(x)
y=np.array(y)
x_train, x_test, y_train, y_test= train_test_split(x, y, test_size = 0.25)
inputs=Input(shape=(4,))
x=Dense(10,activation="relu",name="input_layer")(inputs)
x=Dense(10,activation="relu")(x)
x=Dense(15,activation="relu")(x)
x=Dense(3,activation="softmax",name="output_layer")(x)
model=Model(inputs,x)
sgd = SGD(lr=0.05, momentum=0.9, decay=0.0001, nesterov=False)
model.compile(optimizer=sgd, loss="sparse_categorical_crossentropy", metrics=["accuracy"])
model.fit(x_train,y_train,batch_size=20, epochs=100, verbose=0)
AndroidStudio 中的代码(我有 4 个输入数字字段、1 个输出字段和 1 个按钮。当按钮被点击时调用 predictClick 方法):
static{
System.loadLibrary("tensorflow_inference");
}
String model_name ="file:///android_asset/iris_model.pb";
String output_name = "output_layer";
String input_name = "input_data";
TensorFlowInferenceInterface tfinterface;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
tfinterface = new TensorFlowInferenceInterface(getAssets(),model_name);
}
public void predictClick(View v){
TextView antwort = (TextView)findViewById(R.id.antwort);
EditText number1 = (EditText)findViewById(R.id.number1);
EditText number2 = (EditText)findViewById(R.id.number2);
EditText number3 = (EditText)findViewById(R.id.number3);
EditText number4 = (EditText)findViewById(R.id.number4);
Button predict = (Button)findViewById(R.id.button);
float[] result = {};
String zahl1 = number1.getText().toString();
String zahl2 = number2.getText().toString();
String zahl3 = number3.getText().toString();
String zahl4 = number4.getText().toString();
float n1 = Float.parseFloat(zahl1);
float n2 = Float.parseFloat(zahl2);
float n3 = Float.parseFloat(zahl3);
float n4 = Float.parseFloat(zahl4);
float[] inputs={n1,n2,n3,n4};
//im pretty sure these lines cause the error
tfinterface.feed(input_name,inputs,4,1);
tfinterface.run(new String[]{output_name});
tfinterface.fetch(output_name,result);
antwort.setText(Float.toString(result[0]));
}
构建运行没有错误,但是当我点击预测按钮时,应用程序崩溃。当我离开台词时
tfinterface.feed(input_name,inputs,4,1);
tfinterface.run(new String[]{output_name});
tfinterface.fetch(output_name,result);
应用程序运行正常,所以我认为这是错误的来源。
【问题讨论】:
-
请从 LogCat 中查找并包含与该崩溃相关的异常。否则无法为您提供帮助。
-
请记住,您的代码不是“在 Android Studio 中”运行的。那只是一个构建代码的文本编辑器。代码在实际的 Android 设备上运行
标签: android tensorflow