我认为不能修改tf.Variable 的trainable 属性。但是,有多种解决方法。
假设你有两个变量:
import tensorflow as tf
v1 = tf.Variable(tf.random_normal([2, 2]), name='v1')
v2 = tf.Variable(tf.random_normal([2, 2]), name='v2')
当您使用tf.train.Optimizer 类及其子类进行优化时,默认情况下它从tf.GraphKeys.TRAINABLE_VARIABLES 集合中获取变量。默认情况下,您使用 trainable=True 定义的每个变量都会添加到此集合中。您可以做的是清除此集合并仅将那些您愿意优化的变量附加到它。例如,如果我只想优化v1 而不是v2:
var_list = tf.trainable_variables()
print(var_list)
# [<tf.Variable 'v1:0' shape=(2, 2) dtype=float32_ref>,
# <tf.Variable 'v2:0' shape=(2, 2) dtype=float32_ref>]
tf.get_default_graph().clear_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
cleared_var_list = tf.trainable_variables()
print(cleared_var_list)
# []
tf.add_to_collection(tf.GraphKeys.TRAINABLE_VARIABLES, var_list[0])
updated_var_list = tf.trainable_variables()
print(updated_var_list)
# [<tf.Variable 'v1:0' shape=(2, 2) dtype=float32_ref>]
另一种方法是使用优化器的 var_list 关键字参数并传递您想要在训练期间更新的那些变量(在执行 train_op 期间):
optimizer = tf.train.GradientDescentOptimizer(0.01)
train_op = optimizer.minimize(loss, var_list=[v1])