【问题标题】:How to convert from trainable tensorflow variables to not trainable tensorflow variable?如何从可训练的张量流变量转换为不可训练的张量流变量?
【发布时间】:2019-04-11 04:02:41
【问题描述】:

我需要恢复一个 DNN(VGG16Net),并使用迁移学习来构建另一个网络。所以在这里我需要将一些过滤器、偏差变量从可训练的 tensorflow 变量转换为不可训练的变量(我使用的是原生 tensorflow 框架,而不是 keras 或任何更高级别的包)。

例如从卷积层 4_1 获取权重 我用了 conv4_3_filter=sess.graph.get_tensor_by_name('conv4_3/filter:0') 但变量“conv4_3_filter”始终是可训练变量。 所以,在这里我试图找到一种通用方法将任何张量流变量从可训练转换为不可训练。 我该如何解决这个问题?

【问题讨论】:

  • 如果可能的话,我不知道如何将可训练变量转换为不可训练变量,但是您可以让变量保持可训练但只训练其他变量,参见例如.stackoverflow.com/questions/54303730/…

标签: python tensorflow


【解决方案1】:

我认为不能修改tf.Variabletrainable 属性。但是,有多种解决方法。

假设你有两个变量:

import tensorflow as tf

v1 = tf.Variable(tf.random_normal([2, 2]), name='v1')
v2 = tf.Variable(tf.random_normal([2, 2]), name='v2')

当您使用tf.train.Optimizer 类及其子类进行优化时,默认情况下它从tf.GraphKeys.TRAINABLE_VARIABLES 集合中获取变量。默认情况下,您使用 trainable=True 定义的每个变量都会添加到此集合中。您可以做的是清除此集合并仅将那些您愿意优化的变量附加到它。例如,如果我只想优化v1 而不是v2

var_list = tf.trainable_variables()
print(var_list)
# [<tf.Variable 'v1:0' shape=(2, 2) dtype=float32_ref>,
#  <tf.Variable 'v2:0' shape=(2, 2) dtype=float32_ref>]

tf.get_default_graph().clear_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

cleared_var_list = tf.trainable_variables()
print(cleared_var_list)
# []

tf.add_to_collection(tf.GraphKeys.TRAINABLE_VARIABLES, var_list[0])

updated_var_list = tf.trainable_variables()
print(updated_var_list)
# [<tf.Variable 'v1:0' shape=(2, 2) dtype=float32_ref>]

另一种方法是使用优化器的 var_list 关键字参数并传递您想要在训练期间更新的那些变量(在执行 train_op 期间):

optimizer = tf.train.GradientDescentOptimizer(0.01)
train_op = optimizer.minimize(loss, var_list=[v1])

【讨论】:

    猜你喜欢
    • 2020-02-29
    • 1970-01-01
    • 2021-06-28
    • 1970-01-01
    • 1970-01-01
    • 2017-09-14
    • 1970-01-01
    • 1970-01-01
    • 2018-09-17
    相关资源
    最近更新 更多