【问题标题】:Running trained tensorflow model in C++在 C++ 中运行经过训练的张量流模型
【发布时间】:2017-12-15 02:15:17
【问题描述】:

我已经使用 tensorflow 在 python 中训练了一个图像分类网络。训练好的模型保存为.pb。现在,我想测试模型,我需要在 C++ 中完成。

我曾使用numpy 来操作和处理数据。在训练阶段,图像作为 numpy 数组传入。图像被拉伸为一维数组,并且类标签被添加到这个数组中。

我对如何在 C++ 中运行模型时传递图像数据感到困惑,因为我无法使用 numpy。我使用numpy 操作来操作和处理数据。如果我必须在 C++ 中执行,我应该以什么格式传递数据。

以下是我如何训练和保存我的模型

def trainModel(data):
    global_step = tf.Variable(0, name='global_step', trainable=False)
    X, y,keep_prob = modelInputs((741, 620, 1),4)
    logits = cnnModel(X,keep_prob)
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y), name="cost")
    optimizer = tf.train.AdamOptimizer(.0001, name='Adam').minimize(cost)
    prediction = tf.argmax(logits, 1, name="prediction")
    correct_pred = tf.equal(prediction, tf.argmax(y, 1), name="correct_pred")
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name='accuracy')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        batch_size = 30
        for e in range(11):
            batch_x, batch_y = data.next_batch(batch_size)
            batch_y = batch_y.astype('int32')
            x = np.reshape(batch_x, [batch_size, 741, 620, 1])
            labels = np.zeros(shape=(batch_size,4))
            labels[np.arange(len(labels)),batch_y]=1
            sess.run(optimizer, feed_dict={X: x, y: labels,keep_prob:0.5})
            global_step.assign(e).eval()
        saver.save(sess, './my_test_model',global_step=global_step)

*741x620 是图片的大小

【问题讨论】:

  • Tensorflow 确实可以轻松构建如此复杂的网络,您甚至不知道自己拥有什么。您计划如何从 C++ 运行您的 NN?手卷网络或图书馆?注意:图书馆推荐在这里是题外话
  • @MSalters 我打算像this 那样做。我不确定解决方案是否可能涉及图书馆推荐。我想用 C++ 调用模型,但是当我无法访问 numpy 时,我不知道如何传递数据。
  • 好的,这很重要。你打算仍然使用 tensorflow。这留下了如何将输入插入到 tensorflow 中的问题。
  • @MSalters 完全正确。

标签: c++ tensorflow


【解决方案1】:

在 C++ 中使用图形的说明可以在 here 找到。

下面是一些使用图像作为输入的代码:

tensorflow::Tensor keep_prob = tensorflow::Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape());
keep_prob.scalar<float>()() = 1.0;

tensorflow::Tensor input_tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({1,height,width,depth}));
auto input_tensor_mapped = input_tensor.tensor<float, 4>();
const float * source_data = (float*) img.data;  // here img is an opencv image, but if it's just a float array this code is very easy to adapt
// copying the image data into the corresponding tensor
for (int y = 0; y < height; ++y) {
    const float* source_row = source_data + (y * width * depth);
    for (int x = 0; x < width; ++x) {
        const float* source_pixel = source_row + (x * depth);
        for (int c = 0; c < depth; ++c) {
            const float* source_value = source_pixel + c;
            input_tensor_mapped(0, y, x, c) = *source_value;
        }
    }
}
std::vector<tensorflow::Tensor> finalOutput;

tensorflow::Status run_status = this->tf_session->Run({{InputName,input_tensor}, 
                                                       {dropoutPlaceHolderName, keep_prob}},
                                                      {OutputName},
                                                      {},
                                                      &finalOutput);

【讨论】:

  • 当我评估 logits 节点时如何打印finalOutput 的内容(将是一个有 4 个条目的张量,有 4 个类)。我使用auto output_c = finalOutput[0].scalar&lt;float&gt;(); std::cout &lt;&lt; output_c() &lt;&lt; "\n";这个来打印预测节点的输出,其中输出只是一个数字。
  • 我在网上找到了这段代码,不记得 xhan 也不记得为什么我们需要展平,但它对我有用:` tensorflow::TTypes::Flat indices_flat = finalOutput[1] .flat(); for (int i = 0; i
【解决方案2】:

您可以使用 C++ API,如上一个答案所示,但是,使用 TensorFlow C++ API 进行编译可能会让人头疼。我建议您使用cppflow,它是一个简单易用的 de C API 包装器。它允许您将数据作为 C++ 向量提供给网络:

Model m("mymodel.pb");
m.restore("./my_test_model");

auto X = new Tensor(m, "X");
auto y = new Tensor(m, "y");
auto keep_prob = new Tensor(m, "keep_prob");
auto result = new Tensor(m, "prediction");

std::vector<float> xdata, ydata;
// Fill the vectors with data
X->set_data(xdata);
y->set_data(ydata);

m.run({X,y,keep_prob}, result);

std::vector<float> myresult = result->get_data<float>();

您无需安装完整的 TensorFlow 即可使用此包装器,只需下载 C API 的 .so 即可。

【讨论】:

  • 这是否适用于 saver.save() 生成的“.meta”格式而不是“.pb”格式?
猜你喜欢
  • 2018-08-07
  • 1970-01-01
  • 2017-09-22
  • 2020-02-29
  • 2021-12-24
  • 2023-03-12
  • 1970-01-01
  • 2016-10-05
  • 2020-02-27
相关资源
最近更新 更多