【问题标题】:How to get weights from .pb model in Tensorflow如何从 Tensorflow 中的 .pb 模型中获取权重
【发布时间】:2017-09-09 05:36:52
【问题描述】:

我训练了一个模型,然后通过冻结该模型创建了一个 .pb 文件。 所以,我的问题是如何从 .pb 文件中获取权重,或者我必须做更多的过程来获取权重

@mrry,请指导我。

【问题讨论】:

  • 不幸的是,我不是很喜欢,但是冻结模型会给你一个 GraphDef;你可以parse a GraphDef in Python,它会有常量的值(包括你的冻结重量)。
  • 哦.. 非常感谢..

标签: python tensorflow


【解决方案1】:

让我们首先从.pb 文件中加载图表。

import tensorflow as tf
from tensorflow.python.platform import gfile

GRAPH_PB_PATH = './model/tensorflow_inception_v3_stripped_optimized_quantized.pb' #path to your .pb file
with tf.Session(config=config) as sess:
  print("load graph")
  with gfile.FastGFile(GRAPH_PB_PATH,'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')
    graph_nodes=[n for n in graph_def.node]

现在,当您将图形冻结到 .pb 文件时,您的变量将转换为 Const 类型,并且作为训练变量的权重也将作为 Const 存储在 .pb 文件中。 graph_nodes 包含图中的所有节点。但是我们对所有Const 类型的节点都感兴趣。

wts = [n for n in graph_nodes if n.op=='Const']

wts 的每个元素都是 NodeDef 类型。它有几个属性,例如名称,操作等。值可以提取如下 -

from tensorflow.python.framework import tensor_util

for n in wts:
    print "Name of the node - %s" % n.name
    print "Value - " 
    print tensor_util.MakeNdarray(n.attr['value'].tensor)

希望这能解决您的问题。

【讨论】:

  • 如何加载 pb 文件然后使用它?像 network = load_pb_file(Alexnet.pb) network.run(image)
  • 你好?什么是一 .pb 文件?我该如何使用它?我有一个 AlexNet.pb
  • @kong 是的,pb文件就是冻结文件。 Alexnet.pb 适合您。
  • graph_def.ParseFromString
  • sess.graph.as_default() 似乎不需要,因为当使用with 打开会话时,它会自动设置为默认值。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2018-05-18
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2022-01-05
  • 1970-01-01
相关资源
最近更新 更多