【发布时间】:2016-10-06 16:38:58
【问题描述】:
我用input_shape=[125,100,100,1] 训练了一个模型来预测 8 个浮点数。我在演示中更改了these options 以适应我的模型设置。
然后我在批量大小中添加了另一个选项
private static final int BATCH_SIZE = 125;
在the C++ side 中,我打印了一些调试信息以查看张量的形状:
LOG (INFO) << "input node: " << input_tensors[0].first << ", "
<< "input shape: " << input_tensors[0].second.shape().DebugString();
tensorflow_inference_jni.cc:198 输入节点:input_node,输入形状: [125,100,100,1]
但应用程序在调用 vars->session->Run() 函数时崩溃
A/libc: Fatal signal 6 (SIGABRT), code -6 in tid 16574 (InferenceThread)
现在,如果我设置 BATCH_SIZE = 1(始终使用以 125 批大小训练的模型)应用程序不会崩溃,但它会返回此错误:
E/native: tensorflow_inference_jni.cc:213 Error during inference: Invalid argument: Input to reshape is a tensor with 8 values, but the requested shape has 1000
[[Node: output_node = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](fullyconnected2_1/BiasAdd, output_node/shape)]]
此错误中请求的形状 1000 是 num_output * batch_size 我猜是 (8 * 125)。
我错过了什么吗?
【问题讨论】:
-
你是在主线程上执行推理吗?
-
不,有一个background thread 用于推断
标签: android c++ tensorflow