【Question title】: Android TensorFlow IllegalArgumentException error
【Posted】: 2017-08-22 13:24:48
【Question】:

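I'm using Android Studio and the Android version of TensorFlow for image recognition. This is not continuous tracking and recognition, just recognition of a single image. I already have the graph .pb file and the labels .txt file, and I've configured the required settings. But there is one big problem: I keep hitting the same error, which is about the image dimensions. Here are the error log and my source code.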

java.lang.IllegalArgumentException: input must be 4-dimensional[1,1,299,299,3]
    [[Node: ResizeBilinear = ResizeBilinear[T=DT_FLOAT, align_corners=false, _device="/job:localhost/replica:0/task:0/cpu:0"](ExpandDims, ResizeBilinear/size)]]
    at org.tensorflow.Session.run(Native Method)
    at org.tensorflow.Session.access$100(Session.java:48)
    at org.tensorflow.Session$Runner.runHelper(Session.java:295)
    at org.tensorflow.Session$Runner.run(Session.java:245)
    at org.tensorflow.contrib.android.TensorFlowInferenceInterface.run(TensorFlowInferenceInterface.java:144)
    at com.example.yuuuuu.tensorTest.TensorFlowImageClassifier.recognizeImage(TensorFlowImageClassifier.java:119)
    at com.example.yuuuuu.tensorTest.MainActivity.runTensor(MainActivity.java:69)
    at com.example.yuuuuu.tensorTest.MainActivity$1.onClick(MainActivity.java:42)
    at android.view.View.performClick(View.java:6205)
    at android.widget.TextView.performClick(TextView.java:11103)
    at android.view.View$PerformClick.run(View.java:23653)
    at android.os.Handler.handleCallback(Handler.java:751)
    at android.os.Handler.dispatchMessage(Handler.java:95)
    at android.os.Looper.loop(Looper.java:154)
    at android.app.ActivityThread.main(ActivityThread.java:6682)
    at java.lang.reflect.Method.invoke(Native Method)
    at com.android.internal.os.ZygoteInit$MethodAndArgsCaller.run(ZygoteInit.java:1520)
    at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:1410)

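I don't know where the problem is. Looking at the first line, [1,1,299,299,3]: I think the two 299s are the image size and one of the 1s is ImageStd, but I don't know what the other 1 and the 3 are. I used the same code as the official code on the TensorFlow GitHub and changed only a few parts. Here is MainActivity.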

public class MainActivity extends AppCompatActivity {

private static final String MODEL_FILE = "file:///android_asset/optimized_graph.pb";
private static final String LABEL_FILE = "file:///android_asset/output_labels.txt";
private static final String INPUT_NAME = "Cast";
private static final String OUTPUT_NAME = "final_result";
private static final int INPUT_SIZE = 299;
private static final int IMAGE_MEAN = 117;
private static final float IMAGE_STD = 1;

private Classifier classifier;
private TextView textView;
private ImageView img;
private Button button;

@Override
protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    setContentView(R.layout.activity_main);

    textView = (TextView)findViewById(R.id.textView);
    button = (Button)findViewById(R.id.btn);
    img = (ImageView)findViewById(R.id.img);

    button.setOnClickListener(new View.OnClickListener(){
        public void onClick(View v){
            runTensor();
        }
    });

    initTensor();
}

public void initTensor(){
    classifier = TensorFlowImageClassifier.create(
            getAssets(),
            MODEL_FILE,
            LABEL_FILE,
            INPUT_SIZE,
            IMAGE_MEAN,
            IMAGE_STD,
            INPUT_NAME,
            OUTPUT_NAME
    );
}

public void runTensor(){
    Bitmap bitmap = BitmapFactory.decodeResource(getResources(), R.drawable.test);
    bitmap = Bitmap.createScaledBitmap(bitmap, INPUT_SIZE, INPUT_SIZE, false);

    img = (ImageView)findViewById(R.id.img);
    img.setImageBitmap(bitmap);

    final List<Classifier.Recognition> results = classifier.recognizeImage(bitmap);
    textView.setText(results.toString());
}

protected void onDestroy(){
    super.onDestroy();
    classifier.close();
}

}

Here is the Classifier, which is the same as the official code.

public interface Classifier {

public class Recognition{
    private final String id;
    private final String title;
    private final Float confidence;
    private RectF location;

    public Recognition(
            final String id, final String title, final Float confidence, final RectF location){
        this.id = id;
        this.title = title;
        this.confidence = confidence;
        this.location = location;
    }

    public String getId(){return id;}
    public String getTitle(){return title;}
    public Float getConfidence(){return confidence;}
    public RectF getLocation(){return location;}
    public void setLocation(RectF location){this.location = location;}

    public String toString(){
        String resultString = "";
        if (id != null) {
            resultString += "[" + id + "] ";
        }

        if (title != null) {
            resultString += title + " ";
        }

        if (confidence != null) {
            resultString += String.format("(%.1f%%) ", confidence * 100.0f);
        }

        if (location != null) {
            resultString += location + " ";
        }

        return resultString.trim();
    }
}

List<Recognition> recognizeImage(Bitmap bitmap);
void enableStatLogging(final boolean debug);
String getStatString();
void close();
}

Finally, here is TensorFlowImageClassifier, which is also the same as the official code.

public class TensorFlowImageClassifier implements Classifier {
private static final String TAG = "TensorFlowImageClassifier";

private static final int MAX_RESULTS = 3;
private static final float THRESHOLD = 0.1f;

private String inputName;
private String outputName;
private int inputSize;
private int imageMean;
private float imageStd;

private Vector<String> labels = new Vector<String>();
private int[] intValues;
private float[] floatValues;
private float[] outputs;
private String[] outputNames;

private boolean logStats = false;
private TensorFlowInferenceInterface inferenceInterface;
private TensorFlowImageClassifier() {}

/*
assetManager : used to load assets
modelFilename : the pb file
labelFilename : the txt file
inputSize : side length of the square input, inputSize * inputSize
imageMean : mean of the image values
imageStd : standard deviation of the image values?
inputName : label of the image input node
outputName : label of the output node
 */

public static Classifier create(
        AssetManager assetManager, String modelFilename, String labelFilename, int inputSize, int imageMean, float imageStd, String inputName, String outputName){
    TensorFlowImageClassifier c = new TensorFlowImageClassifier();
    c.inputName = inputName;
    c.outputName = outputName;

    String actualFilename = labelFilename.split("file:///android_asset/")[1];
    Log.d(TAG, "reading labels from : " + actualFilename);
    BufferedReader br = null;

    try {
        br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename)));
        String line;
        while((line = br.readLine()) != null){
            c.labels.add(line);
        }
        br.close();
    } catch (IOException e) {
        throw new RuntimeException("failed reading labels" , e);
    }

    c.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);

    final Operation operation = c.inferenceInterface.graphOperation(outputName);
    final int numClasses = (int)operation.output(0).shape().size(1);
    Log.d(TAG, "reading " + c.labels.size() + " labels, size of output layers : " + numClasses);

    c.inputSize = inputSize;
    c.imageMean = imageMean;
    c.imageStd = imageStd;

    c.outputNames = new String[]{outputName};
    c.intValues = new int[inputSize * inputSize];
    c.floatValues = new float[inputSize * inputSize * 3];
    c.outputs = new float[numClasses];

    return c;
}

@RequiresApi(api = Build.VERSION_CODES.JELLY_BEAN_MR2)
public List<Recognition> recognizeImage(final Bitmap bitmap){
    beginSection("recognizeImage");
    beginSection("preprocessBitmap");

    bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
    for(int i = 0; i < intValues.length; i++){
        final int val = intValues[i];
        floatValues[i*3+0] = (((val >> 16) & 0xFF) - imageMean) / imageStd;
        floatValues[i*3+1] = (((val >> 8) & 0xFF) - imageMean) / imageStd;
        floatValues[i*3+2] = ((val & 0xFF) - imageMean) / imageStd;
    }
    endSection();

    beginSection("feed");
    inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);
    endSection();

    beginSection("run");
    inferenceInterface.run(outputNames, logStats);
    endSection();

    beginSection("fetch");
    inferenceInterface.fetch(outputName, outputs);
    endSection();


    PriorityQueue<Recognition> pq = new PriorityQueue<Recognition>(
            3,
            new Comparator<Recognition>(){
                public int compare(Recognition lhs, Recognition rhs){
                    return Float.compare(rhs.getConfidence(), lhs.getConfidence());
                }
            }
    );

    for(int i = 0; i < outputs.length; ++i){
        if(outputs[i] > THRESHOLD){
            pq.add(
                    new Recognition("" + i, labels.size() > i ? labels.get(i) : "unknown", outputs[i], null));
        }
    }

    final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
    int recognitionSize = Math.min(pq.size(), MAX_RESULTS);
    for(int i = 0; i < recognitionSize; ++i){
        recognitions.add(pq.poll());
    }
    endSection();

    return recognitions;
}

public void enableStatLogging(boolean logStats){this.logStats = logStats;}
public String getStatString(){return inferenceInterface.getStatString();}
public void close(){inferenceInterface.close();}
}

If you know how to fix this code, please tell me how.

【Comments】:

    Tags: android tensorflow


    【Solution 1】:

    java.lang.IllegalArgumentException: input must be 4-dimensional[1,1,299,299,3]

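    The error message explains the problem: the shape reaching the node has five entries instead of four, so the tensor is 5-dimensional rather than 4-dimensional. That is, the node should probably be receiving something like [1,299,299,3] rather than [1,1,299,299,3]. The trace shows ResizeBilinear taking its input from an ExpandDims node, so the extra leading 1 may be added inside the graph to whatever you feed into your input node.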
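To make the shape arithmetic concrete, here is a minimal sketch in plain Java. It does not call TensorFlow; `ShapeSketch`, `expandDims` and `elementCount` are hypothetical names used only for illustration. It assumes the graph's ExpandDims node inserts a dimension of size 1 at axis 0, which is one plausible reading of the stack trace and not something the question confirms.

```java
import java.util.Arrays;

/** Hypothetical helper, not part of TensorFlow: mimics how ExpandDims changes a shape. */
final class ShapeSketch {

    /** Returns a copy of shape with a dimension of size 1 inserted at index axis. */
    static long[] expandDims(long[] shape, int axis) {
        long[] out = new long[shape.length + 1];
        for (int i = 0, j = 0; i < out.length; i++) {
            out[i] = (i == axis) ? 1 : shape[j++];
        }
        return out;
    }

    /** Number of values a tensor of this shape holds. */
    static long elementCount(long[] shape) {
        long n = 1;
        for (long d : shape) {
            n *= d;
        }
        return n;
    }

    public static void main(String[] args) {
        // Shape passed to feed() in the question: batch, height, width, channels.
        long[] fed = {1, 299, 299, 3};
        // Assumption: the graph's ExpandDims node adds a dimension at axis 0.
        long[] seenByResize = expandDims(fed, 0);
        System.out.println(Arrays.toString(fed) + " -> " + Arrays.toString(seenByResize));
        System.out.println("rank seen by ResizeBilinear = " + seenByResize.length);
        // Matches the length of floatValues (inputSize * inputSize * 3).
        System.out.println("values fed = " + elementCount(fed));
    }
}
```

Read this way, the leading 1 is the batch size, the two 299s are height and width, and the 3 is the number of colour channels; none of them is IMAGE_STD. If the assumption holds, the rank would stay at 4 either by feeding a 3-dimensional [299,299,3] tensor into the current input node or by feeding the 4-dimensional tensor into a node that sits after the resize. Which node is right depends on the graph, so confirm it before changing INPUT_NAME.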

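    It's hard to tell from your question which code changes you actually made. If you could make your changes as a single Git commit, it might be easier for someone to work out which change caused the problem.

    You could try viewing your TensorFlow model in TensorBoard to inspect the input and output nodes and check that they match the values you have configured:
    https://medium.com/@daj/how-to-inspect-a-pre-trained-tensorflow-model-5fd2ee79ced0

    【Comments】: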

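      【Solution 2】:

      Well, when I've worked with native libraries, I've noticed that they usually don't fetch files from assets themselves; you need to copy the file to an accessible storage path and pass the absolute path to the library.

      Your error may come from loading the resource.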
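The copy step described above can be sketched with `java.io` alone. `AssetCopySketch` and `copyToFile` are hypothetical names, and the in-memory stream stands in for an opened asset so the example runs off-device. On Android the stream would typically come from `getAssets().open(...)` and the target file would live under a directory such as `getFilesDir()`. Whether copying is needed in this case is not established by the question, since the trace fails inside `Session.run` rather than while the model is being loaded.

```java
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

/** Hypothetical helper: copies a stream (for example an opened asset) to a regular file. */
final class AssetCopySketch {

    /** Writes everything from in to target, closes both, and returns the absolute path. */
    static String copyToFile(InputStream in, File target) throws IOException {
        try (InputStream src = in; OutputStream dst = new FileOutputStream(target)) {
            byte[] buffer = new byte[8192];
            int read;
            while ((read = src.read(buffer)) != -1) {
                dst.write(buffer, 0, read);
            }
        }
        return target.getAbsolutePath();
    }

    public static void main(String[] args) throws IOException {
        // Stand-in for an asset stream; the bytes are placeholder data, not a real model.
        InputStream fakeAsset = new ByteArrayInputStream(new byte[] {1, 2, 3, 4});
        File target = File.createTempFile("optimized_graph", ".pb");
        target.deleteOnExit();
        String path = copyToFile(fakeAsset, target);
        System.out.println(path + " (" + target.length() + " bytes)");
    }
}
```

The returned absolute path is what would then be handed to a library that cannot read from assets directly.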

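      【Comments】:

      • Thanks for your reply. I tried, but I'm not sure what exactly you mean by accessible file storage. Could you tell me what it is?
      • Copy it to a public location, such as /sdcard/downloads or wherever you like.
      • Thanks. I'll try it.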