【发布时间】:2020-07-28 02:21:41
【问题描述】:
我将一个微型 bert 模块转换为 tflite 并使用 tensorflow lite c++ api 运行推理。
当batch size=1时,tensorflow lite平均运行时间为0.6ms,而tensorflow平均运行时间为1ms(默认线程数);当 batch size=10 时,tensorflow lite 的平均运行时间为 5ms,而 tensorflow 的平均运行时间为 3ms。
当我尝试应用 SetNumThreads(4) 时,似乎 tensorflow lite 在多线程加速方面没有做任何事情。
尽管 cpu 使用率从 100% 变为 200%,但 SetNumThreads(4) 和 SetNumThreads(1) 执行相同的运行时。
我想知道这是 tflite 在 X86 桌面上的正常性能吗?
这是我的自定义 tflite c++ 代码的一部分
class Session {
public:
Session() {
model_ = NULL;
interpreter_ = NULL;
}
bool Open(const std::string &saved_model) {
model_ = tflite::FlatBufferModel::BuildFromFile(saved_model.c_str());
if (!model_) {
return false;
}
tflite::InterpreterBuilder(*model_.get(), resolver_)(&interpreter_);
if (!interpreter_) {
return false;
}
interpreter_->SetNumThreads(4);
return true;
}
bool Run(std::vector<int> &dims, int32_t *tok_id, int32_t *msk_id, int32_t *seg_id, float *output) const {
int tok_index = interpreter_->inputs()[2];
int msk_index = interpreter_->inputs()[1];
int seg_index = interpreter_->inputs()[0];
interpreter_->ResizeInputTensor(tok_index, dims);
interpreter_->ResizeInputTensor(msk_index, dims);
interpreter_->ResizeInputTensor(seg_index, dims);
if(interpreter_->AllocateTensors() != kTfLiteOk) //remove AllocateTensors() did not change the runtime
return false;
int32_t bytes = dims[0] * dims[1] * sizeof(int32_t);
int32_t* tok_tensor = interpreter_->typed_tensor<int32_t>(tok_index);
memcpy(tok_tensor, tok_id, bytes);
int32_t* msk_tensor = interpreter_->typed_tensor<int32_t>(msk_index);
memcpy(msk_tensor, msk_id, bytes);
int32_t* seg_tensor = interpreter_->typed_tensor<int32_t>(seg_index);
memcpy(seg_tensor, seg_id, bytes);
if(interpreter_->Invoke() != kTfLiteOk)
return false;
bytes = dims[0] * sizeof(float);
float* result = interpreter_->typed_output_tensor<float>(0);
memcpy(output, result, bytes);
return true;
}
private:
std::unique_ptr<tflite::FlatBufferModel> model_;
std::unique_ptr<tflite::Interpreter> interpreter_;
tflite::ops::builtin::BuiltinOpResolver resolver_;
};
【问题讨论】:
-
慢是什么意思?毕竟,我们正在谈论 m i l l i s e c o n d s...
-
相比于tensorflow c api,tflite 更慢.....
标签: c++ tensorflow-lite