【问题标题】:Running TensorFlow on multicore devices在多核设备上运行 TensorFlow
【发布时间】:2017-10-21 17:01:20
【问题描述】:

我有一个基本的 Android TensorFlowInference 示例,可以在单线程中正常运行。

public class InferenceExample {

    private static final String MODEL_FILE = "file:///android_asset/model.pb";
    private static final String INPUT_NODE = "intput_node0";
    private static final String OUTPUT_NODE = "output_node0"; 
    private static final int[] INPUT_SIZE = {1, 8000, 1};
    public static final int CHUNK_SIZE = 8000;
    public static final int STRIDE = 4;
    private static final int NUM_OUTPUT_STATES = 5;

    private static TensorFlowInferenceInterface inferenceInterface;

    public InferenceExample(final Context context) {
        inferenceInterface = new TensorFlowInferenceInterface(context.getAssets(), MODEL_FILE);
    }

    public float[] run(float[] data) {

        float[] res = new float[CHUNK_SIZE / STRIDE * NUM_OUTPUT_STATES];

        inferenceInterface.feed(INPUT_NODE, data, INPUT_SIZE[0], INPUT_SIZE[1], INPUT_SIZE[2]);
        inferenceInterface.run(new String[]{OUTPUT_NODE});
        inferenceInterface.fetch(OUTPUT_NODE, res);

        return res;
    }
}

当按照下面的示例在 ThreadPool 中运行时,该示例会崩溃并出现各种异常,包括 java.lang.ArrayIndexOutOfBoundsExceptionjava.lang.NullPointerException,所以我猜它不是线程安全的。

InferenceExample inference = new InferenceExample(context);

ExecutorService executor = Executors.newFixedThreadPool(NUMBER_OF_CORES);    
Collection<Future<?>> futures = new LinkedList<Future<?>>();

for (int i = 1; i <= 100; i++) {
    Future<?> result = executor.submit(new Runnable() {
        public void run() {
           inference.call(randomData);
        }
    });
    futures.add(result);
}

for (Future<?> future:futures) {
    try { future.get(); }
    catch(ExecutionException | InterruptedException e) {
        Log.e("TF", e.getMessage());
    }
}

是否可以通过TensorFlowInferenceInterface 使用多核 Android 设备?

【问题讨论】:

    标签: java android tensorflow


    【解决方案1】:

    为了使InferenceExample 线程安全,我将TensorFlowInferenceInterfacestatic 更改为run 方法synchronized

    private TensorFlowInferenceInterface inferenceInterface;
    
    public InferenceExample(final Context context) {
        inferenceInterface = new TensorFlowInferenceInterface(assets, model);
    }
    
    public synchronized float[] run(float[] data) { ... }
    

    然后我在numThreads 中循环使用InterferenceExample 实例的列表。

    for (int i = 1; i <= 100; i++) {
        final int id = i % numThreads;
        Future<?> result = executor.submit(new Runnable() {
            public void run() {
                list.get(id).run(data);
            }
        });
        futures.add(result);
    }
    

    但这确实会提高性能 在 8 核设备上,此峰值为 2 的 numThreads,并且在 Android Studio Monitor 中仅显示约 50% 的 CPU 使用率。

    【讨论】:

    • 我强烈建议不要使用这种方法。当然,您已经做到了,可以同时调用run,但这只有在您不更改输入时才有意义(通过调用TensorFlowInferenceInterface.feed())。大概,您希望您的线程提供不同的输入,以便计算可以在它们上运行。您建议的方法对此并不安全。
    • 为什么不同的输入不安全?在循环中按“id”顺序存储期货的微小变化,我将知道哪个输入与哪个输出匹配。
    • 哦,抱歉,我看错了,没有注意到对 feed()fetch() 的调用在您同步的 run() 中。所以我在上面的评论中弄错了。但是,您的方法将限制并行性,因为这实质上会序列化 TensorFlow 会话的使用 - 一次只有一个线程可以执行模型。
    • 我看到并行性有限,但Session 不是TensorFlowInferenceInterface 的实例属性,所以我应该让n 会话为n InferenceExamples 并行运行我创造?
    • 我今天显然看不清楚。是的,如果你有nTensorFlowInferenceInterface 对象,那么你就有nSession 对象。不过,您也正在加载模型 n 次,如果您的模型占用空间很大,您将消耗 n 次所需的内存。但这可能与您所看到的并行性限制正交。
    【解决方案2】:

    TensorFlowInferenceInterface 类不是线程安全的(因为它在对 feedrunfetch 等的调用之间保持状态。

    但是,它构建在 TensorFlow Java API 之上,其中 Session 类的对象是线程安全的。

    因此您可能希望直接使用底层 Java API,TensorFlowInferenceInterface 的构造函数创建一个 Session 并使用从 AssetManager (code) 加载的 Graph 进行设置。

    希望对您有所帮助。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2016-11-06
      • 2017-01-17
      • 1970-01-01
      • 1970-01-01
      • 2012-01-05
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多