【问题标题】:How can I access the weights of a recurrent cell in Tensorflow?如何在 Tensorflow 中访问循环单元的权重?
【发布时间】:2018-08-01 01:53:27
【问题描述】:

提高深度 Q 学习任务稳定性的一种方法是为网络维护一组目标权重,这些权重更新缓慢并用于计算 Q 值目标。因此,在学习过程的不同时间,前向传递中使用了两组不同的权重。对于普通的 DQN,这并不难实现,因为权重是可以在 feed_dict 中设置的张量流变量,即:

sess = tf.Session()
input = tf.placeholder(tf.float32, shape=[None, 5])
weights = tf.Variable(tf.random_normal(shape=[5,4], stddev=0.1)
bias = tf.Variable(tf.constant(0.1, shape=[4])
output = tf.matmul(input, weights) + bias
target = tf.placeholder(tf.float32, [None, 4])
loss = ...

...

#Here we explicitly set weights to be the slowly updated target weights
sess.run(output, feed_dict={input: states, weights: target_weights, bias: target_bias})

# Targets for the learning procedure are computed using this output.

....

#Now we run the learning procedure, using the most up to date weights,
#as well as the previously computed targets
sess.run(loss, feed_dict={input: states, target: targets})

我想在 DQN 的循环版本中使用这种目标网络技术,但我不知道如何访问和设置循环单元内使用的权重。具体来说,我使用的是 tf.nn.rnn_cell.BasicLSTMCell,但我想知道如何对任何类型的循环单元执行此操作。

【问题讨论】:

    标签: python machine-learning tensorflow reinforcement-learning


    【解决方案1】:

    BasicLSTMCell 不会将其变量作为其公共 API 的一部分公开。我建议您在图表中查找这些变量的名称并提供这些名称(这些名称不太可能更改,因为它们位于检查点中,更改这些名称会破坏检查点的兼容性)。

    或者,您可以制作一个 BasicLSTMCell 的副本,它确实公开了变量。我认为这是最干净的方法。

    【讨论】:

    • 这行得通,谢谢亚历山大。对于任何想要更多细节的人,当您将循环单元输入tf.nn.dynamicrnn() 时,会创建权重和偏差变量。在会话中运行tf.initialize_all_variables() 后,将有两个新的可训练张量,如果你运行tf.trainable_variables(),你可以看到它们。在我的例子中,它们被命名为 RNN/BasicLSTMCell/Linear/Matrix:0RNN/BasicLSTMCell/Linear/Bias:0
    【解决方案2】:

    您可以使用下面的行来获取图中的变量

    variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    

    然后您可以检查这些变量以了解它们是如何变化的

    【讨论】: