【问题标题】:ONNX Runtime Inference | session.run() multiprocessingONNX 运行时推理 | session.run() 多处理
【发布时间】:2022-07-20 20:07:20
【问题描述】:

目标:在多个 CPU 内核上并行运行推理

我正在尝试使用simple_onnxruntime_inference.ipynb 进行推理。

个人:

outputs = session.run([output_name], {input_name: x})

很多:

outputs = session.run(["output1", "output2"], {"input1": indata1, "input2": indata2})

依次:

%%time
outputs = [session.run([output_name], {input_name: inputs[i]})[0] for i in range(test_data_num)]

这个 Multiprocessing tutorial 提供了许多并行化任何任务的方法。

但是,我想知道哪种方法最适合session.run(),无论是否通过outputs

如何并行推断所有输出和输入?

代码:

import onnxruntime
import multiprocessing as mp

session = onnxruntime.InferenceSession('bert.opt.quant.onnx')

i = 0
# First Input
input_name = session.get_inputs()[i].name
print("Input Name  :", input_name)

# First Output
output_name = session.get_outputs()[i].name
print("Output Name  :", output_name)  

pool = mp.Pool(mp.cpu_count())

# PARALLELISE THIS LINE
outputs = [session.run([], {input_name: inputs[i]})[0] for i in range(test_data_num)]
# outputs = pool.starmap(func, zip(iter_1, iter_2))

pool.close()

print(results)

更新solution 建议使用 starmap()zip() 来传递函数名称和 2 个单独的可迭代对象。

用这个替换行:

outputs = pool.starmap(session.run, zip([output_name], [ {input_name: inputs[i]}[0] for i in range(test_data_num) ]))

追溯:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-45-0aab302a55eb> in <module>
     25 #%%time
     26 #outputs = [session.run([output_name], {input_name: inputs[i]})[0] for i in range(test_data_num)]
---> 27 outputs = pool.starmap(session.run, zip([output_name], [ {input_name: inputs[i]}[0] for i in range(test_data_num) ]))
     28 
     29 pool.close()

<ipython-input-45-0aab302a55eb> in <listcomp>(.0)
     25 #%%time
     26 #outputs = [session.run([output_name], {input_name: inputs[i]})[0] for i in range(test_data_num)]
---> 27 outputs = pool.starmap(session.run, zip([output_name], [ {input_name: inputs[i]}[0] for i in range(test_data_num) ]))
     28 
     29 pool.close()

KeyError: 0

【问题讨论】:

    标签: parallel-processing multiprocessing inference onnx onnxruntime


    【解决方案1】:
    def run_inference(i):
        output_name = session.get_outputs()[0].name
        return session.run([output_name], {input_name: inputs[i]})[0]  # [0] bc array in list
    
    outputs = pool.map(run_inference, [i for i in range(test_data_num)])
    

    欢迎大家批评

    【讨论】:

      猜你喜欢
      • 2021-05-23
      • 2018-12-01
      • 1970-01-01
      • 2022-10-16
      • 2020-01-28
      • 2020-09-19
      • 2012-08-20
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多