为了这个thread。我会将我的评论改写为答案。
发布完整示例需要 CMake 设置或将文件放入特定目录以运行 bazel。因为我喜欢第一种方式,它会打破这篇文章的所有限制以涵盖所有部分,我想重定向到我为 TF > v1.5 测试过的complete implementation in C99, C++, GO without Bazel。
在 C++ 中加载图表并不比在 Python 中困难多少,鉴于您已经从源代码编译了 TensorFlow。
从创建一个 MWE 开始,它会创建一个非常转储的网络图,这对于弄清楚事情是如何工作的总是一个好主意:
import tensorflow as tf
x = tf.placeholder(tf.float32, shape=[1, 2], name='input')
output = tf.identity(tf.layers.dense(x, 1), name='output')
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(tf.global_variables())
saver.save(sess, './exported/my_model')
关于这部分的 SO 这里可能有很多答案。所以我就让它留在这里,不再解释。
在 Python 中加载
在用其他语言做东西之前,我们可以尝试在 python 中正确地做——从某种意义上说:我们只需要用 C++ 重写它。
甚至在 python 中恢复也很容易,例如:
import tensorflow as tf
with tf.Session() as sess:
# load the computation graph
loader = tf.train.import_meta_graph('./exported/my_model.meta')
sess.run(tf.global_variables_initializer())
loader = loader.restore(sess, './exported/my_model')
x = tf.get_default_graph().get_tensor_by_name('input:0')
output = tf.get_default_graph().get_tensor_by_name('output:0')
它没有帮助,因为大多数这些 API 端点在 C++ API 中不存在(还没有?)。另一个版本是
import tensorflow as tf
with tf.Session() as sess:
metaGraph = tf.train.import_meta_graph('./exported/my_model.meta')
restore_op_name = metaGraph.as_saver_def().restore_op_name
restore_op = tf.get_default_graph().get_operation_by_name(restore_op_name)
filename_tensor_name = metaGraph.as_saver_def().filename_tensor_name
sess.run(restore_op, {filename_tensor_name: './exported/my_model'})
x = tf.get_default_graph().get_tensor_by_name('input:0')
output = tf.get_default_graph().get_tensor_by_name('output:0')
等一下。您始终可以使用print(dir(object)) 来获取restore_op_name、...等属性。
与其他操作一样,恢复模型是 TensorFlow 中的一项操作。我们只是调用此操作并提供路径(字符串张量)作为输入。我们甚至可以编写自己的restore 操作
def restore(sess, metaGraph, fn):
restore_op_name = metaGraph.as_saver_def().restore_op_name # u'save/restore_all'
restore_op = tf.get_default_graph().get_operation_by_name(restore_op_name)
filename_tensor_name = metaGraph.as_saver_def().filename_tensor_name # u'save/Const'
sess.run(restore_op, {filename_tensor_name: fn})
即使这看起来很奇怪,但现在在 C++ 中做同样的事情有很大帮助。
在 C++ 中加载
从平常的东西开始
#include <tensorflow/core/public/session.h>
#include <tensorflow/core/public/session_options.h>
#include <tensorflow/core/protobuf/meta_graph.pb.h>
#include <string>
#include <iostream>
typedef std::vector<std::pair<std::string, tensorflow::Tensor>> tensor_dict;
int main(int argc, char const *argv[]) {
const std::string graph_fn = "./exported/my_model.meta";
const std::string checkpoint_fn = "./exported/my_model";
// prepare session
tensorflow::Session *sess;
tensorflow::SessionOptions options;
TF_CHECK_OK(tensorflow::NewSession(options, &sess));
// here we will put our loading of the graph and weights
return 0;
}
您应该能够通过将其放入 TensorFlow 存储库并使用 bazel 来编译它,或者只需按照说明 here 使用 CMake。
我们需要创建这样一个由tf.train.import_meta_graph 创建的meta_graph。这可以通过
tensorflow::MetaGraphDef graph_def;
TF_CHECK_OK(ReadBinaryProto(tensorflow::Env::Default(), graph_fn, &graph_def));
在 C++ 中,从文件中读取图形不与在 Python 中导入图形相同。我们需要在会话中通过
TF_CHECK_OK(sess->Create(graph_def.graph_def()));
通过查看上面奇怪的pythonrestore函数:
restore_op_name = metaGraph.as_saver_def().restore_op_name
restore_op = tf.get_default_graph().get_operation_by_name(restore_op_name)
filename_tensor_name = metaGraph.as_saver_def().filename_tensor_name
我们可以用 C++ 编写等效的代码
const std::string restore_op_name = graph_def.saver_def().restore_op_name()
const std::string filename_tensor_name = graph_def.saver_def().filename_tensor_name()
有了这个,我们就可以运行操作了
sess->Run(feed_dict, // inputs
{}, // output_tensor_names (we do not need them)
{restore_op}, // target_node_names
nullptr) // outputs (there are no outputs this time)
创建 feed_dict 本身可能是一个帖子,这个答案已经足够长了。它只涵盖最重要的东西。我想重定向到我针对 TF > v1.5 测试过的complete implementation in C99, C++, GO without Bazel。这并不难——在plain C version 的情况下它可能会变得很长。