【问题标题】:restoring weights of an already saved Tensorflow .pb model恢复已保存的 Tensorflow .pb 模型的权重
【发布时间】:2020-11-12 05:19:40
【问题描述】:

我在这里看到了很多关于恢复已保存的 TF 模型的帖子,但没有人能回答我的问题。 Using TF 1.0.0

具体来说,我有兴趣查看inceptionv3 模型的权重,该模型在.pb 文件here 中公开提供。我设法使用一小段 Python 代码将其恢复,并且可以访问 tensorboard 中的图形高级视图:

from tensorflow.python.platform import gfile

INCEPTION_LOG_DIR = '/tmp/inception_v3_log'

if not os.path.exists(INCEPTION_LOG_DIR):
    os.makedirs(INCEPTION_LOG_DIR)
with tf.Session() as sess:
    model_filename = './model/tensorflow_inception_v3_stripped_optimized_quantized.pb'
    with gfile.FastGFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _= tf.import_graph_def(graph_def,name='')
    writer = tf.train.SummaryWriter(INCEPTION_LOG_DIR, graph_def)
    writer=tf.summary.FileWriter(INCEPTION_LOG_DIR, graph_def)
    writer.close()

但是,我无法访问任何层的权重。

tensors= tf.import_graph_def(graph_def,name='')

返回空,即使我添加任意return_elements=。它有任何重量吗?如果是,这里的适当程序是什么?谢谢。

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    使用此代码打印张量的值:

    with tf.Session() as sess:
        print sess.run('your_tensor_name')
    

    您可以使用此代码检索张量名称:

        op = sess.graph.get_operations()
        for m in op : 
        print(m.values())
    

    【讨论】:

    • @thanks @vahid 回答。那么返回值是张量的权重吗?即:print sess.run('mixed_6/join/concat_dim:0') 返回 3。我们如何解释这个?
    • 作为后续,你能看看我的新问题吗:stackoverflow.com/questions/45408463/…
    【解决方案2】:

    恢复权重和打印权重是有区别的。前者表示希望从已保存的 ckpt 文件中导入权重值以进行再训练或推理,而后者可能用于检查。 .pb 文件也将模型参数编码为 tf.constant() ops。结果,模型参数不会出现在 tf.trainable_variables() 中,因此您不能直接使用.pb恢复权重。根据您的问题,我认为您只是想“查看”权重以进行检查。

    让我们首先从.pb 文件中加载图表。

    import tensorflow as tf
    from tensorflow.python.platform import gfile
    
    GRAPH_PB_PATH = './model/tensorflow_inception_v3_stripped_optimized_quantized.pb' #path to your .pb file
    with tf.Session(config=config) as sess:
      print("load graph")
      with gfile.FastGFile(GRAPH_PB_PATH,'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')
        graph_nodes=[n for n in graph_def.node]
    

    现在,当您将图形冻结到 .pb 文件时,您的变量将转换为 Const 类型,并且作为训练变量的权重也将作为 Const 存储在 .pb 文件中。 graph_nodes 包含图中的所有节点。但我们对所有Const 类型的节点都感兴趣。

    wts = [n for n in graph_nodes if n.op=='Const']
    

    wts 的每个元素都是 NodeDef 类型。它有几个属性,例如名称,操作等。值可以提取如下 -

    from tensorflow.python.framework import tensor_util
    
    for n in wts:
        print "Name of the node - %s" % n.name
        print "Value - " 
        print tensor_util.MakeNdarray(n.attr['value'].tensor)
    

    希望这能解决您的问题。

    【讨论】:

    【解决方案3】:

    您可以使用此代码获取张量的名称。

    [tf.get_default_graph().as_graph_def().node 中张量的tensor.name]

    【讨论】:

    • 谢谢@Beta。但我的问题是找到一种方法来检索权重,它只返回所有张量的巨大列表:[[u'mixed_10/join/concat_dim', u'mixed_9/join/concat_dim', u'mixed_8/join/concat_dim',...]
    • @Amir:从您的问题来看,您似乎试图调用特定的权重名称。但可能是它们的名称不同。这就是我提到该代码的原因。 Valid给出了更详细的答案。
    【解决方案4】:

    打印 .pb 模型权重的小工具:

    import argparse
    
    import tensorflow as tf
    from tensorflow.python.framework import tensor_util
    
    
    def print_pb_weights(pb_filepath):
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(pb_filepath, "rb") as f:
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name='')
    
        for node in graph_def.node:
            if node.op == 'Const':
                print('-' * 60)
                print('op:', node.op)
                print('name:', node.name)
                arr = tensor_util.MakeNdarray(node.attr['value'].tensor)
                print('shape:', list(arr.shape))
                print(arr)
    
    
    if __name__ == '__main__':
        parser = argparse.ArgumentParser()
        parser.add_argument('pb_filepath')
        args = parser.parse_args()
    
        print_pb_weights(args.pb_filepath)
    

    【讨论】:

    • 它是否适用于使用tensorflow.org/guide/saved_model 保存在TF v2 中的那些.pb 文件?
    • 如果 saved_model 只是 .pb 并且一些附加元数据存储在单独的文件中,我认为是的。
    • 试过了,还是不行。 tf.gfile.GFile() 在 TFv2 中完全改变了
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2020-01-16
    • 2016-08-16
    • 1970-01-01
    • 1970-01-01
    • 2019-09-29
    相关资源
    最近更新 更多