【问题标题】:Error when parsing graph_def from string从字符串解析 graph_def 时出错
【发布时间】:2018-03-14 19:28:55
【问题描述】:

我正在尝试运行一个非常简单的 Tensorflow 图保存为 .pb 文件,但在解析它时出现此错误:

Traceback (most recent call last):
  File "test_import_stripped_bm.py", line 28, in <module>
    graph_def.ParseFromString(fileContent)
  File "/usr/local/lib/python3.5/dist-packages/google/protobuf/message.py", line 185, in ParseFromString
    self.MergeFromString(serialized)
  File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/python_message.py", line 1069, in MergeFromString
    if self._InternalParse(serialized, 0, length) != length:
  File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/python_message.py", line 1105, in InternalParse
    pos = field_decoder(buffer, new_pos, end, self, field_dict)
  File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 633, in DecodeField
    if value._InternalParse(buffer, pos, new_pos) != new_pos:
  File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/python_message.py", line 1105, in InternalParse
    pos = field_decoder(buffer, new_pos, end, self, field_dict)
  File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 612, in DecodeRepeatedField
    if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
  File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/python_message.py", line 1105, in InternalParse
    pos = field_decoder(buffer, new_pos, end, self, field_dict)
  File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 743, in DecodeMap
    if submsg._InternalParse(buffer, pos, new_pos) != new_pos:
  File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/python_message.py", line 1095, in InternalParse
    new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
  File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 850, in SkipField
    return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end)
  File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 799, in _SkipGroup
    new_pos = SkipField(buffer, pos, end, tag_bytes)
  File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 850, in SkipField
    return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end)
  File "/usr/local/lib/python3.5/dist-packages/google/protobuf/internal/decoder.py", line 814, in _SkipFixed32
    raise _DecodeError('Truncated message.')
google.protobuf.message.DecodeError: Truncated message.

这是我用来将其写入 .pb 的代码:

import tensorflow as tf

builder = tf.saved_model.builder.SavedModelBuilder('models/TEST-3')

w1 = tf.Variable(tf.random_normal((2,2)), name="w1")
w2 = tf.Variable(tf.random_normal((2,2)), name="w2")

sess = tf.Session()
sess.run(tf.global_variables_initializer())

builder.add_meta_graph_and_variables(sess, tags=[tf.saved_model.tag_constants.SERVING], clear_devices = True)

builder.save()
sess.close()

这是解析它的代码:

import tensorflow as tf
import os

model_path = os.path.join('models/TEST-3', 'saved_model.pb')
with open(model_path, mode='rb') as f:
    fileContent = f.read()
graph_def = tf.GraphDef()
graph_def.ParseFromString(fileContent)

要查看我必须做的确切错误

export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

在运行它之前。 我也在 python 2 和 3 上用不同的 tensorflow 版本尝试过这个,我在 Ubuntu 16.04 上运行。在带有 tensorflow 0.9.0rc0 的 python 2.7 上,我设法得到了一个稍微不同的错误:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/message.py", line 185, in ParseFromString
    self.MergeFromString(serialized)
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py", line 1091, in MergeFromString
    if self._InternalParse(serialized, 0, length) != length:
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py", line 1127, in InternalParse
    pos = field_decoder(buffer, new_pos, end, self, field_dict)
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py", line 633, in DecodeField
    if value._InternalParse(buffer, pos, new_pos) != new_pos:
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py", line 1127, in InternalParse
    pos = field_decoder(buffer, new_pos, end, self, field_dict)
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py", line 612, in DecodeRepeatedField
    if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py", line 1127, in InternalParse
    pos = field_decoder(buffer, new_pos, end, self, field_dict)
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py", line 612, in DecodeRepeatedField
    if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py", line 1127, in InternalParse
    pos = field_decoder(buffer, new_pos, end, self, field_dict)
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py", line 489, in DecodeRepeatedField
    value.append(_ConvertToUnicode(buffer[pos:new_pos]))
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py", line 469, in _ConvertToUnicode
    return local_unicode(byte_str, 'utf-8')
UnicodeDecodeError: 'utf8' codec can't decode byte 0x80 in position 18: 'utf8' codec can't decode byte 0x80 in position 18: invalid start byte in field: tensorflow.FunctionDef.Node.ret

我可以使用此代码解析其他 .pb 图形,例如 https://github.com/taey16/tf/blob/master/imagenet/classify_image_graph_def.pb

提前致谢。

【问题讨论】:

    标签: python python-2.7 python-3.x tensorflow protocol-buffers


    【解决方案1】:

    这里的问题是你试图解析一个SavedModel 协议缓冲区,就好像它是一个GraphDef。尽管SavedModel 包含GraphDef,但它们具有不同的二进制格式。以下代码,使用tf.saved_model.loader.load() 应该可以工作:

    import tensorflow as tf
    
    with tf.Session(graph=tf.Graph()) as sess:
        tf.saved_model.loader.load(
            sess, [tf.saved_model.tag_constants.SERVING], "models/TEST-3") 
    

    【讨论】:

    • 谢谢。我想获得一个序列化的 GraphDef,我设法使用:sess.graph_def.SerializeToString() 然后加载它:graph_def = tf.GraphDef() graph_def.ParseFromString(graph_string) tf.import_graph_def(graph_def, name=”),因为我需要像那样加载它。现在的问题是变量没有初始化。我想在 Spark 中使用该模型并遵循本指南databricks.com/blog/2016/01/25/…
    • GraphDef 本身并没有提供足够的信息来初始化变量。该教程已经过时了......它早于SavedModel,这使得做这种事情变得更加容易。在我的答案中使用命令加载SavedModel 就足够了,因为它会为您初始化变量。
    • 是的,这行得通,但现在的问题是它非常慢,因为它每次都从磁盘读取它,因为我在不同的进程中运行它并且我不能重用相同的 tf 会话。有没有办法从字符串中加载 SavedModel 而不是从目录中读取它?或者可能以其他方式?
    • 如果您查看the implementation,您可能能够提取解析SavedModel 原型的代码并运行一次以获得GraphDef。 (然而,让检查点加载代码在不同的会话中工作是很困难的,因为Saver 实现本质上是基于文件的。)
    【解决方案2】:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(fileContent)
    

    这里的 fileContent 应该是一个**“Frozen Graph”。 Tensorflow 也提供了相同的 api,参考Tensorflow freeze_graph API

    另一种创建冻结图的方法是:

    with tf.Session(graph=tf.Graph()) as sess:
          saver = tf.train.import_meta_graph(<.meta file>)
          saver.restore(sess, <checkpoint>)
          output_graph_def = tf.graph_util.convert_variables_to_constants(
                        sess,
                        tf.get_default_graph().as_graph_def(),
                        [comma separated output nodes name]
                    ) 
          # Saving "output_graph_def " in a file and generate frozen graph.
          with tf.gfile.GFile('frozen_graph.pb', "wb") as f:
          f.write(output_graph_def.SerializeToString())
    

    使用 freeze_graph.pb 作为

    graph_def.ParseFromString("frozen_graph.pb")
    

    所以首先使用 Saver 对象生成 .meta 和其他文件。完成后创建冻结图。

    【讨论】:

    • 你的方法 graph_def.ParseFromString(filepath) 引发 TypeError: a bytes-like object is required, not 'str'
    猜你喜欢
    • 2017-07-17
    • 1970-01-01
    • 2017-09-26
    • 1970-01-01
    • 2012-10-04
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多