【问题标题】:Tensorflow: Importing pretrained model (mobilenet, .pb, .ckpt)Tensorflow:导入预训练模型(mobilenet、.pb、.ckpt)
【发布时间】:2018-06-17 20:46:58
【问题描述】:

我一直在研究在 tensorflow 中导入预训练模型的检查点。 这样做的目的是让我可以检查它的结构,并将其用于图像 分类。

具体来说,mobilenet 型号found here。我找不到任何 从各种 *.ckpt.* 文件导入模型的合理方法,并使用 一些论坛嗅探我发现了 Github 用户 StanislawAntol 写的一个要点 据称将所述文件转换为冻结模型 ProtoBuf (.pb) 文件。这 要点是here

运行脚本给了我一堆 .pb 文件,我希望我能工作 和。确实,this SO question 似乎回应了我的祈祷。

我一直在尝试以下代码的变体,但无济于事。任何物体 由tf.import_graph_def 返回,似乎是无类型。

import tensorflow as tf
from tensorflow.python.platform import gfile

model_filename = LOCATION_OF_PB_FILE

with gfile.FastGFile(model_filename,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
g_in = tf.import_graph_def(graph_def, name='')

print(g_in)

这里有什么我遗漏的吗?整个转换为 .pb 是否错误?

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    tf.import_graph_def 不返回图表,它填充范围内的“默认图表”。有关返回值的详细信息,请参阅documentation for tf.import_graph_def

    在您的情况下,您可以使用tf.get_default_graph() 检查图表。例如:

    with gfile.FastGFile(model_filename, 'rb') as f:
      graph_def = tf.GraphDef()
      graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
    
    g = tf.get_default_graph()
    print(len(g.get_operations()))
    

    请参阅documentation for tf.Graph,了解有关“默认图”概念和范围的更多详细信息。

    希望对您有所帮助。

    【讨论】:

    • 太棒了,这让我走上了正轨。谢谢。第 6 行的函数应该是 tf.get_default_graph()
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2019-01-20
    • 2019-01-29
    • 2020-08-16
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多