【问题标题】:How to make the weights of an RNN cell untrainable in Tensorflow?如何在 Tensorflow 中使 RNN 单元的权重无法训练?
【发布时间】:2017-07-06 21:49:57
【问题描述】:

我正在尝试制作一个 Tensorflow 图,其中部分图已经预先训练并在预测模式下运行,而其余部分则在训练。我已经像这样定义了我的预训练单元:

rnn_cell = tf.contrib.rnn.BasicLSTMCell(100)

state0 = tf.Variable(pretrained_state0,trainable=False)
state1 = tf.Variable(pretrained_state1,trainable=False)
pretrained_state = [state0, state1]

outputs, states = tf.contrib.rnn.static_rnn(rnn_cell, 
                                            data_input,
                                            dtype=tf.float32,
                                            initial_state = pretrained_state)

将初始变量设置为trainable=False 没有帮助。这些只是用于初始化权重,因此权重仍然会发生变化。

我仍然需要在训练步骤中运行优化器,因为我的模型的其余部分需要训练。但是如何防止优化器更改此 rnn 单元中的权重?

有没有等价于trainable=False的rnn_cell?

【问题讨论】:

  • 预训练模型的输出是你要训练的新模型的输入?如果是这样,为什么不预先计算出预训练模型?我的意思是只保留两个独立的图表。
  • @YuwenYan 你说得对,我能做到。我希望通过同时运行这两个图表来避免预先计算,因为确保所有数据排列起来会更简单,并且每次我想更改预训练模型时都会节省一个步骤

标签: machine-learning tensorflow deep-learning recurrent-neural-network


【解决方案1】:

您可以使用tf.stop_gradient() 来防止图形的pretrained 部分更新其权重,也可以使用optimiser() 来指定应训练图形的哪些部分。第二种方法涉及:

 #Create variable scope for the trainable parts of the graph: tf.variable_scope('train').

 # get trainable variables
 t_vars = tf.trainable_variables()
 train_vars = [var for var in t_vars if var.name.startswith('train')]
 # train only the variables of a particular scope
 opt = optimizer.minimize(cost, var_list=train_vars)

【讨论】:

  • 这看起来像我需要的。奇怪的是,它只阻止了 ANN 权重的训练,而不是 RNN 的权重。不过我会继续寻找。您是否曾经使用 RNN 权重进行此操作或看过示例?
  • 它适用于任何事情。使用 tf.variable_scope() 分配变量范围进行训练,然后仅使用范围内定义的那些variables 在 optimiser.minimize() 中更新