【发布时间】:2019-03-31 20:40:47
【问题描述】:
我已经成功训练了一个数字分类器,现在想在 Android 中使用它。我以前从未用过 TensorFlow,所以是照着若干教程一步步做的,目前已经到了需要在 Android 应用中加载我生成的 .pb 文件这一步。加载时需要提供 inputName 和 outputName,但我不知道它们应该填什么。根据 Python 脚本,我猜 outputName 应该是 final_result,其余参数则完全没有头绪。下面是我在 Android 中的代码:
// Builds a TensorFlowClassifier and registers it in mClassifiers.
// Argument order follows TensorFlowClassifier.create(assetManager, name, modelPath,
// labelFile, inputSize, inputName, outputName, feedKeepProb).
mClassifiers.add(
TensorFlowClassifier.create(
context.getAssets(),
"?????", // <- what goes here ? (name: the string later returned by name())
"clasifier.pb", // modelPath: handed to TensorFlowInferenceInterface together with the AssetManager
"labels.txt", // labelFile: read line by line into the labels list
100, // inputSize: stored by create(); recognize() uses its own width/height arguments instead
"????", // <- what goes here ? (inputName: the node recognize() feeds the pixels into)
"???", // <- what goes here ? (outputName: the node recognize() runs and fetches)
true) // feedKeepProb: when true, recognize() also feeds "keep_prob" with the value 1
);
import android.content.res.AssetManager;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
/**
 * {@link Classifier} that runs a TensorFlow graph through
 * {@link TensorFlowInferenceInterface} and maps the output scores to text labels.
 *
 * <p>Instances are built with {@link #create}. {@link #recognize} writes into a single
 * pre-allocated output buffer, so one instance should not be used for overlapping calls.
 */
public class TensorFlowClassifier implements Classifier {
    // A class is only reported when its score is strictly greater than this value.
    private static final float THRESHOLD = 0.1f;

    private TensorFlowInferenceInterface tfHelper;
    private String name;
    private String inputName;
    private String outputName;
    private int inputSize;
    private boolean feedKeepProb;
    private List<String> labels;
    private float[] output;
    private String[] outputNames;

    /**
     * Reads the label file from the assets, one label per line, preserving file order.
     *
     * @param am       asset manager used to open the file
     * @param fileName asset path of the label file
     * @return the labels in file order; the position of a label is the index of the
     *         output score it describes
     * @throws IOException if the asset cannot be opened or read
     */
    private static List<String> readLabels(AssetManager am, String fileName) throws IOException {
        List<String> labels = new ArrayList<>();
        // Failures are propagated on purpose: a silently empty or truncated label list
        // would only surface later as an out-of-range lookup inside recognize().
        try (BufferedReader br =
                new BufferedReader(new InputStreamReader(am.open(fileName), "UTF-8"))) {
            String line;
            while ((line = br.readLine()) != null) {
                labels.add(line);
            }
        }
        return labels;
    }

    /**
     * Creates a classifier from a model file and a label file stored in the assets.
     *
     * @param assetManager asset manager that provides both files
     * @param name         display name, returned by {@link #name()}
     * @param modelPath    asset path of the model file
     * @param labelFile    asset path of the label file (one label per line)
     * @param inputSize    stored on the instance; not read by {@link #recognize}
     * @param inputName    name of the graph node the pixels are fed into
     * @param outputName   name of the graph node the scores are fetched from
     * @param feedKeepProb whether {@code "keep_prob"} is fed with 1 before each run
     * @return a ready-to-use classifier
     * @throws IOException if the label file cannot be read or contains no labels
     */
    public static TensorFlowClassifier create(AssetManager assetManager,
            String name,
            String modelPath,
            String labelFile,
            int inputSize,
            String inputName,
            String outputName,
            boolean feedKeepProb) throws IOException {
        TensorFlowClassifier c = new TensorFlowClassifier();

        c.name = name;
        c.inputName = inputName;
        c.outputName = outputName;
        c.outputNames = new String[] { outputName };
        c.inputSize = inputSize;
        c.feedKeepProb = feedKeepProb;

        c.labels = readLabels(assetManager, labelFile);
        if (c.labels.isEmpty()) {
            throw new IOException("No labels found in asset: " + labelFile);
        }

        c.tfHelper = new TensorFlowInferenceInterface(assetManager, modelPath);

        // Pre-allocate one score slot per label instead of assuming exactly ten classes,
        // so every output index has a label to look up.
        c.output = new float[c.labels.size()];
        return c;
    }

    @Override
    public String name() {
        return name;
    }

    /**
     * Runs the model on one image and returns the best-scoring label.
     *
     * @param pixels image data, fed to the input node with dimensions
     *               {@code {1, width, height, 1}}
     * @param width  second dimension of the fed tensor
     * @param height third dimension of the fed tensor
     * @return a {@code Classification} updated with the highest score above the
     *         threshold, or left as constructed when no class qualifies
     */
    @Override
    public Classification recognize(final float[] pixels, final int width, final int height) {
        tfHelper.feed(inputName, pixels, 1, width, height, 1);
        if (feedKeepProb) {
            tfHelper.feed("keep_prob", new float[] { 1 });
        }

        tfHelper.run(outputNames);
        tfHelper.fetch(outputName, output);

        Classification ans = new Classification();
        for (int i = 0; i < output.length; ++i) {
            String label = labels.get(i);
            // The label "0" is deliberately never reported, matching the original filter.
            if (!label.equals("0") && output[i] > THRESHOLD && output[i] > ans.getConf()) {
                ans.update(output[i], label);
            }
        }
        return ans;
    }
}
由于无法在此贴出 Python 脚本,可以在这里查看:https://github.com/MicrocontrollersAndMore/TensorFlow_Tut_2_Classification_Walk-through/blob/master/retrain.py
【问题讨论】:
标签: android python tensorflow