【问题标题】:Tensorflow: how to use pretrained weights in new graph?Tensorflow:如何在新图中使用预训练的权重?
【发布时间】:2018-05-08 01:43:30
【问题描述】:

我正在尝试使用带有 python 框架的 tensorflow 构建一个带有 CNN 的对象检测器。我想训练我的模型首先只进行对象识别(分类),然后使用 pretarined 模型的几个卷积层训练它来预测边界框。我需要替换全连接层,可能还有一些最后的卷积层。因此,出于这个原因,我想知道是否可以将 only 权重从用于训练对象分类器的张量流图中导入到我将训练以进行对象检测的新定义图。所以基本上我想做这样的事情:

# here I initialize the new graph
conv_1=tf.nn.conv2d(in, weights_from_old_graph)
conv_2=tf.nn.conv2d(conv_1, weights_from_old_graph)
...
conv_n=tf.nn.nnconv2d(conv_n-1,randomly_initialized_weights)
fc_1=tf.matmul(conv_n, randomly_initalized_weights)

【问题讨论】:

  • 您可能想阅读以下内容:Choose Variables to Save and Restore
  • @Aechlys,哦,是的,谢谢。但是我以前见过这个,据我所知,这种方法意味着我应该只保存那些我想恢复的变量,但是为了进行实验,我想保存所有变量,然后选择我想在新图表。

标签: python tensorflow


【解决方案1】:

使用不带参数的 saver 来保存整个模型。

tf.reset_default_graph()
v1 = tf.get_variable("v1", [3], initializer = tf.initializers.random_normal)
v2 = tf.get_variable("v2", [5], initializer = tf.initializers.random_normal)
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, save_path='./test-case.ckpt')

    print(v1.eval())
    print(v2.eval())
saver = None
v1 = [ 2.1882825   1.159807   -0.26564872]
v2 = [0.11437789 0.5742971 ]

然后在要恢复到特定值的模型中,将要恢复的变量名列表或{"variable name": variable} 的字典传递给Saver

tf.reset_default_graph()
b1 = tf.get_variable("b1", [3], initializer= tf.initializers.random_normal)
b2 = tf.get_variable("b2", [3], initializer= tf.initializers.random_normal)
saver = tf.train.Saver(var_list={'v1': b1})

with tf.Session() as sess:
  saver.restore(sess, "./test-case.ckpt")
  print(b1.eval())
  print(b2.eval())
INFO:tensorflow:Restoring parameters from ./test-case.ckpt
b1 = [ 2.1882825   1.159807   -0.26564872]
b2 = FailedPreconditionError: Attempting to use uninitialized value b2

【讨论】:

  • 天哪,现在我明白了他们在那篇教程中写的内容......我现在觉得有点愚蠢,哈哈。非常感谢)
【解决方案2】:

虽然我同意 Aechlys 恢复变量。当我们想要修复这些变量时,问题就更难了。例如,我们训练了这些变量,我们想在另一个模型中使用它们,但这次没有训练它们(像迁移学习一样训练新变量)。你可以看到我发布的答案here

快速示例:

 with tf.session() as sess:
    new_saver = tf.train.import_meta_graph(pathToMeta)
    new_saver.restore(sess, pathToNonMeta) 

    weight1 = sess.run(sess.graph.get_tensor_by_name("w1:0")) 


 tf.reset_default_graph() #this will eliminate the variables we restored


 with tf.session() as sess:
    weights = 
       {
       '1': tf.Variable(weight1 , name='w1-bis', trainable=False)
       }
...

我们现在确定恢复的变量不是图表的一部分。

【讨论】:

  • 我离实现这个还很远,现在,我只是在猜测这个。但是,您的方法与 Aechlys 之间的区别在于,您恢复旧模型的图形并从该图形中获取张量,然后仅使用这些张量的数值创建新图形。在他的方法中,他定义了新图并直接恢复权重。所以也许在定义新图的过程中设置 tf.get_variable(trainable=False) 是可能的?
  • 我的假设是 .meta 文件包含图形中操作和变量的所有定义,包括变量的属性。并且 .data (最大的那个)文件包含数值,可能当你恢复它时,它只是用相应的数值初始化图中同名的权重。但同样,这只是猜测)
  • 确实,flag 已经被添加了!