【问题标题】:Tensorflow: Download and run pretrained VGG or ResNet modelTensorflow:下载并运行预训练的 VGG 或 ResNet 模型
【发布时间】:2019-02-08 14:46:17
【问题描述】:

让我们从头开始。到目前为止,我自己已经在 Tensorflow 中创建并训练了小型网络。在训练期间,我保存我的模型并在我的目录中获取以下文件:

model.ckpt.meta
model.ckpt.index
model.ckpt.data-00000-of-00001

稍后,我加载保存在network_dir 中的模型进行一些分类并提取模型的可训练变量。

saver = tf.train.import_meta_graph(network_dir + ".meta")
variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="NETWORK")

现在我想使用更大的预训练模型,例如 VGG16 或 ResNet,并希望使用我的代码来实现。我想加载预训练模型,如我自己的网络,如上所示。

在这个网站上,我发现了很多预训练模型:

https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models

我下载了 VGG16 检查点,发现这些只是训练出来的参数。

我想知道如何或在哪里可以获得这些预训练网络的保存模型或图形结构?例如,如何使用没有model.ckpt.metamodel.ckpt.indexmodel.ckpt.data-00000-of-00001 文件的VGG16 检查点?

【问题讨论】:

    标签: python tensorflow


    【解决方案1】:

    在权重链接旁边,有一个指向定义模型的代码的链接。例如,对于 VGG16:Code。使用代码创建模型并从检查点恢复变量:

    import tensorflow as tf
    
    slim = tf.contrib.slim
    
    image = ...  # Define your input somehow, e.g with placeholder
    logits, _ = vgg.vgg_16(image)
    predictions = tf.argmax(logits, 1)
    variables_to_restore = slim.get_variables_to_restore()
    
    saver = tf.train.Saver(variables_to_restore)
    with tf.Session() as sess:
        saver.restore(sess, "/path/to/model.ckpt")
    

    因此,vgg.py 中包含的代码将为您创建所有变量。使用 tf-slim 帮助程序,您可以获得列表。然后,只需按照通常的程序。上面有a similar question

    【讨论】:

    • 我该怎么做?你能更具体一点吗?我是初学者。
    • 非常感谢。我导入了 VGG16 代码做import vgg。但是,我收到以下错误NameError: name 'slim' is not defined。如何解决这个问题来获取网络的所有变量?
    • slim = tf.contrib.slim
    • 抱歉,问了这么愚蠢的问题。
    • 您能否展示我如何提取 VGG16 的所有激活。通常我通过activations=tf.get_collection('Activations') 得到它们,但这在这里似乎是不可能的。获得激活的正确方法是什么?
    猜你喜欢
    • 2018-10-16
    • 2018-07-25
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-06-09
    • 2021-01-31
    • 2020-08-16
    相关资源
    最近更新 更多