我正在为此编写更好的文档,但现在这里是我当前草稿中的摘录,可能会有所帮助:
在大多数情况下,使用 TensorFlow 训练模型会为您提供一个文件夹,其中包含一个 GraphDef 文件(通常以 .pb 或 .pbtxt 扩展名结尾)和一组检查点文件。移动或嵌入式部署需要的是单个 GraphDef 文件,该文件已被“冻结”,或者将其变量转换为内联常量,以便所有内容都在一个文件中。
要处理转换,您需要 freeze_graph.py 脚本,该脚本保存在 tensorflow/pythons/tools/freeze_graph.py 中。你会像这样运行它:
bazel build tensorflow/tools:freeze_graph
bazel-bin/tensorflow/tools/freeze_graph \
--input_graph=/tmp/model/my_graph.pb \ --input_checkpoint=/tmp/model/model.ckpt-1000 \ --output_graph=/tmp/frozen_graph.pb \
--input_node_names=input_node \
--output_node_names=output_node \
input_graph 参数应指向保存模型架构的 GraphDef 文件。您的 GraphDef 可能以文本格式存储在磁盘上,在这种情况下,它可能以“.pbtxt”而不是“.pb”结尾,您应该在命令中添加一个额外的--input_binary=false 标志。
input_checkpoint 应该是最近保存的检查点。如检查点部分所述,您需要在此处为检查点集提供公共前缀,而不是完整的文件名。
output_graph 定义生成的冻结 GraphDef 将保存在哪里。因为它可能包含大量占用大量文本格式空间的权重值,所以它总是保存为二进制 protobuf。
output_node_names 是您要从中提取图形结果的节点名称列表。这是必要的,因为冻结过程需要了解图形的哪些部分是实际需要的,以及哪些是训练过程的工件,例如汇总操作。只有有助于计算给定输出节点的操作才会被保留。如果你知道你的图将如何被使用,这些应该只是你传递给 Session::Run() 作为获取目标的节点的名称。如果您手头没有这些信息,您可以通过运行summarize_graph 工具获得一些关于可能输出的建议。
由于 TensorFlow 的输出格式随着时间的推移而发生变化,因此还有许多其他不太常用的标志可用,例如 input_saver,但希望您在使用现代版本的框架训练的图形上不需要这些。