【发布时间】:2019-12-29 17:19:14
【问题描述】:
我正在尝试运行脚本以从 tensorflow .pb 模型中获取文本摘要,例如:
OPS counts:
Squeeze : 1
Softmax : 1
BiasAdd : 1
Placeholder : 1
AvgPool : 1
Reshape : 2
ConcatV2 : 9
MaxPool : 13
Sub : 57
Rsqrt : 57
Relu : 57
Conv2D : 58
Add : 114
Mul : 114
Identity : 231
Const : 298
我总体上正在尝试将 .pb 模型转换为 .coremlmodel 并且正在关注这篇文章:
https://hackernoon.com/integrating-tensorflow-model-in-an-ios-app-cecf30b9068d
从 .pb 模型中获取文本摘要是朝着这一目标迈出的一步。我尝试运行以创建文本摘要的代码如下:
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
import time
import operator
import sys
def inspect(model_pb, output_txt_file):
graph_def = graph_pb2.GraphDef()
with open(model_pb, "rb") as f:
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def)
sess = tf.Session()
OPS = sess.graph.get_operations()
ops_dict = {}
sys.stdout = open(output_txt_file, 'w')
for i, op in enumerate(OPS):
print('---------------------------------------------------------------------------------------------------------------------------------------------')
print("{}: op name = {}, op type = ( {} ), inputs = {}, outputs = {}".format(i, op.name, op.type, ", ".join([x.name for x in op.inputs]), ", ".join([x.name for x in op.outputs])))
print('@input shapes:')
for x in op.inputs:
print("name = {} : {}".format(x.name, x.get_shape()))
print('@output shapes:')
for x in op.outputs:
print("name = {} : {}".format(x.name, x.get_shape()))
if op.type in ops_dict:
ops_dict[op.type] += 1
else:
ops_dict[op.type] = 1
print('---------------------------------------------------------------------------------------------------------------------------------------------')
sorted_ops_count = sorted(ops_dict.items(), key=operator.itemgetter(1))
print('OPS counts:')
for i in sorted_ops_count:
print("{} : {}".format(i[0], i[1]))
if __name__ == "__main__":
"""
Write a summary of the frozen TF graph to a text file.
Summary includes op name, type, input and output names and shapes.
Arguments
----------
- path to the frozen .pb graph
- path to the output .txt file where the summary is written
Usage
----------
python inspect_pb.py frozen.pb text_file.txt
"""
if len(sys.argv) != 3:
raise ValueError("Script expects two arguments. " +
"Usage: python inspect_pb.py /path/to/the/frozen.pb /path/to/the/output/text/file.txt")
inspect(sys.argv[1], sys.argv[2])
我运行了这个命令:
python inspect_pb.py /Users/nikhil.c/Desktop/tensorflowModel.pb text_summary.txt
但我没有收到预期的输出,而是收到以下错误消息:
Traceback (most recent call last):
File "inspect_pb.py", line 58, in <module>
inspect(sys.argv[1], sys.argv[2])
File "inspect_pb.py", line 10, in inspect
graph_def.ParseFromString(f.read())
google.protobuf.message.DecodeError: Error parsing message
而且真的不知道从哪里开始。似乎收到相同错误消息的其他类似问题没有太大意义。我该怎么办?
【问题讨论】:
-
补充一下,我用网站上的示例模型进行了尝试,效果很好,所以我相信它可能与实际模型有关。如果您需要更多信息,请告诉我。
-
您的模型可能是使用不兼容的 TensorFlow 版本创建的,或者它以某种方式损坏。你能在某处分享实际的模型文件吗?
-
即 .pb 模型。我在这篇文章之后从 google firebase 下载了:
-
我可以确认这不适用于 TF 1.14。你是用什么版本的 TF 创建这个模型的?
标签: python parsing tensorflow coreml firebase-mlkit