【发布时间】:2024-05-02 09:30:02
【问题描述】:
我正在为 Android 开发一个实时对象分类应用程序。首先,我使用“keras”创建了一个深度学习模型,并且我已经将训练好的模型保存为“model.h5”文件。我想知道如何在 android 中使用该模型进行图像分类。
【问题讨论】:
标签: android tensorflow deep-learning keras
我正在为 Android 开发一个实时对象分类应用程序。首先,我使用“keras”创建了一个深度学习模型,并且我已经将训练好的模型保存为“model.h5”文件。我想知道如何在 android 中使用该模型进行图像分类。
【问题讨论】:
标签: android tensorflow deep-learning keras
您不能将 Keras 直接导出到 Android,但您必须保存模型
将 TensorFlow 配置为您的 Keras 后端。
使用model.save(filepath) 保存模型(你已经这样做了)
然后使用以下解决方案之一加载它:
解决方案 1:在 Tensflow 中导入模型
1- 构建 TensorFlow 模型
2- 构建 Android 应用并调用 Tensorflow。检查这个tutorial 和这个来自谷歌的official demo 以了解如何做到这一点。
解决方案2:在java中导入模型
1- deeplearning4j 一个允许导入 keras 模型的 java 库:tutorial link
2- 在 Android 中使用 deeplearning4j:这很容易,因为您在 Java 世界中。检查this tutorial
【讨论】:
如果您想优化分类方法,那么我建议您使用 armnn android 库对您的模型进行推理。
您必须遵循几个步骤。 1. 在 ubuntu 中安装和设置 arm nn 库。您可以从以下网址获取帮助
https://github.com/ARM-software/armnn/blob/branches/armnn_19_08/BuildGuideAndroidNDK.md
编译后你会得到二进制文件,它将接受输入并给你输出
您可以在任何安卓应用程序中运行该二进制文件
是优化方式。
【讨论】:
首先您需要将 Keras 模型导出到 Tensorflow 模型:
def export_model_for_mobile(model_name, input_node_names, output_node_name):
tf.train.write_graph(K.get_session().graph_def, 'out', \
model_name + '_graph.pbtxt')
tf.train.Saver().save(K.get_session(), 'out/' + model_name + '.chkp')
freeze_graph.freeze_graph('out/' + model_name + '_graph.pbtxt', None, \
False, 'out/' + model_name + '.chkp', output_node_name, \
"save/restore_all", "save/Const:0", \
'out/frozen_' + model_name + '.pb', True, "")
input_graph_def = tf.GraphDef()
with tf.gfile.Open('out/frozen_' + model_name + '.pb', "rb") as f:
input_graph_def.ParseFromString(f.read())
output_graph_def = optimize_for_inference_lib.optimize_for_inference(
input_graph_def, input_node_names, [output_node_name],
tf.float32.as_datatype_enum)
with tf.gfile.FastGFile('out/tensorflow_lite_' + model_name + '.pb', "wb") as f:
f.write(output_graph_def.SerializeToString())
您只需要知道图表的input_nodes_names 和output_node_names。这将创建一个包含多个文件的新文件夹。其中,一个以tensorflow_lite_开头。这是您应该移动到 Android 设备的文件。
然后在 Android 上导入 Tensorflow 库并使用 TensorFlowInferenceInterface 运行您的模型。
implementation 'org.tensorflow:tensorflow-android:1.5.0'
你可以在 Github 上查看我的简单异或示例:
【讨论】: