【问题标题】:Is it possible to make a trainable variable not trainable?是否可以使可训练变量不可训练?
【发布时间】:2016-09-16 11:56:58
【问题描述】:

我在一个范围内创建了一个可训练变量。后来我进入同一个作用域,将作用域设置为reuse_variables,用get_variable检索同一个变量。但是,我无法将变量的可训练属性设置为False。我的get_variable 行是这样的:

weight_var = tf.get_variable('weights', trainable = False)

但变量'weights' 仍在tf.trainable_variables 的输出中。

我可以使用get_variable 将共享变量的trainable 标志设置为False 吗?

我想这样做的原因是我试图在我的模型中重用从 VGG 网络预训练的低级过滤器,我想像以前一样构建图表,检索权重变量,并分配VGG 将值过滤到权重变量,然后在接下来的训练步骤中保持它们不变。

【问题讨论】:

  • minimize() 函数中的 var_list 参数是指定仅对某些变量进行训练的标准位置。

标签: tensorflow pre-trained-model


【解决方案1】:

查看文档和代码后,我无法找到从 TRAINABLE_VARIABLES 中删除变量的方法。

会发生以下情况:

  • 第一次调用tf.get_variable('weights', trainable=True),变量被添加到TRAINABLE_VARIABLES的列表中。
  • 第二次调用tf.get_variable('weights', trainable=False) 时,您会得到相同的变量,但参数trainable=False 无效,因为该变量已经存在于TRAINABLE_VARIABLES 的列表中(并且无法删除它从那里)

第一个解决方案

当调用优化器的minimize 方法时(参见doc.),您可以将var_list=[...] 作为参数传递给您想要优化器的变量。

例如,如果你想冻结除最后两层之外的所有 VGG 层,你可以在 var_list 中传递最后两层的权重。

第二种解决方案

您可以使用tf.train.Saver() 保存变量并在以后恢复它们(请参阅this tutorial)。

  • 首先,您使用所有可训练的变量训练整个 VGG 模型。您可以通过调用 saver.save(sess, "/path/to/dir/model.ckpt") 将它们保存在检查点文件中。
  • 然后(在另一个文件中)使用不可训练变量训练第二个版本。您加载之前使用saver.restore(sess, "/path/to/dir/model.ckpt") 存储的变量。

或者,您可以决定仅将部分变量保存在检查点文件中。请参阅doc 了解更多信息。

【讨论】:

  • 谢谢。我尝试了同样的方法来查找是否可以从TRAINABLE_VARIABLES 的集合中删除一个变量,但不能。看起来定义一个可训练列表的列表对我来说是最好的。
  • 等一下,我刚刚发现 get_collection_ref() 返回了 trainable_variables 集合的引用,我应该能够更改和删除一些条目。我还没有测试过。无论如何,这不太重要。我总是可以过滤从get_collection() 获得的可训练变量并将其发送给优化器。
  • get_collection_ref() 更改可训练集合是否有任何副作用?
  • @Olivier 您不能从可训练列表中删除可训练变量,这是正确的。您可以执行trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES) 来获取可训练变量集合的引用,这是一个python 列表,然后使用pop 和正确的索引从那里删除变量。我对其进行了测试,它阻止了变量被训练。
  • 我正在使用从 TRAINABLE_VARIABLES 集合中删除变量的方法添加答案
【解决方案2】:

如果您只想训练或优化预训练网络的某些层,这就是您需要了解的内容。

TensorFlow 的 minimize 方法接受一个可选参数 var_list,这是一个要通过反向传播调整的变量列表。

如果不指定var_list,则优化器可以调整图中的任何 TF 变量。当您在var_list 中指定一些变量时,TF 会将所有其他变量保持不变。

这是jonbruner 和他的合作者使用的脚本示例。

tvars = tf.trainable_variables()
g_vars = [var for var in tvars if 'g_' in var.name]
g_trainer = tf.train.AdamOptimizer(0.0001).minimize(g_loss, var_list=g_vars)

这会找到他们之前定义的所有变量名中包含“g_”的变量,将它们放入一个列表中,然后对它们运行 ADAM 优化器。

您可以在Quora这里找到相关答案

【讨论】:

    【解决方案3】:

    为了从可训练变量列表中删除一个变量,您可以首先通过以下方式访问该集合: trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES) 在那里,trainable_collection 包含对可训练变量集合的引用。如果你从这个列表中弹出元素,例如trainable_collection.pop(0),你将从可训练变量中删除相应的变量,因此这个变量将不会被训练。

    虽然这适用于 pop,但我仍在努力寻找一种方法来正确使用 remove 和正确的参数,因此我们不依赖于变量的索引。

    编辑: 鉴于您有图中变量的名称(您可以通过检查图 protobuf 或使用 Tensorboard 更容易获得),您可以使用它来循环通过可训练变量列表,然后从可训练集合中删除变量。 示例:假设我希望对名称为 "batch_normalization/gamma:0""batch_normalization/beta:0" NOT 的变量进行训练,但它们已添加到 TRAINABLE_VARIABLES 集合中。我能做的是: `

    #gets a reference to the list containing the trainable variables
    trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
    variables_to_remove = list()
    for vari in trainable_collection:
        #uses the attribute 'name' of the variable
        if vari.name=="batch_normalization/gamma:0" or vari.name=="batch_normalization/beta:0":
            variables_to_remove.append(vari)
    for rem in variables_to_remove:
        trainable_collection.remove(rem)
    

    ` 这将成功地从集合中移除这两个变量,它们将不再被训练。

    【讨论】:

    • 我试过你的方法,但不幸的是它不起作用,至少对于 TF 版本 1.6.0。准确的说,就是“打了胜仗,输了仗”!它确实使可训练变量无法被列为可训练变量 [如调用 tf.trainable_variables() 所示]...但 minimize() 方法继续进行,就好像什么都没发生一样 - 即这些变量仍在训练中!我的猜测是,当第一次调用 minimize() 时,它会拍摄要训练的变量的“快照”,然后不管图形集合 GraphKeys.TRAINABLE_VARIABLES 的后续变化如何,它都会继续使用它们
    • 感谢您注意到这一点。我一直使用 1.10 以上的版本。您是否有一些代码可以说明这种行为,以便我可以检查更高版本并尝试了解它为什么不起作用?
    【解决方案4】:

    您可以使用 tf.get_collection_ref 来获取集合的引用而不是 tf.get_collection

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2018-06-30
      • 1970-01-01
      • 1970-01-01
      • 2023-03-20
      • 2017-12-20
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多