【问题标题】:Tensorflow, how to access all the middle states of an RNN, not just the last stateTensorflow,如何访问一个RNN的所有中间状态,而不仅仅是最后一个状态
【发布时间】:2017-06-22 15:25:47
【问题描述】:

我的理解是tf.nn.dynamic_rnn 在每个时间步返回一个 RNN 单元(例如 LSTM)的输出以及最终状态。我如何才能在所有时间步骤中访问单元状态,而不仅仅是最后一个?例如,我希望能够对所有隐藏状态进行平均,然后在后续层中使用它。

以下是我如何定义一个 LSTM 单元,然后使用 tf.nn.dynamic_rnn 展开它。但这仅给出了 LSTM 的最后一个单元状态。

import tensorflow as tf
import numpy as np

# [batch-size, sequence-length, dimensions] 
X = np.random.randn(2, 10, 8)
X[1,6:] = 0
X_lengths = [10, 6]

cell = tf.contrib.rnn.LSTMCell(num_units=64, state_is_tuple=True)

outputs, last_state = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float64,
    sequence_length=X_lengths,
    inputs=X)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())                                 
out, last = sess.run([outputs, last_state], feed_dict=None)

【问题讨论】:

标签: python tensorflow


【解决方案1】:

这样的事情应该可以工作。

import tensorflow as tf
import numpy as np


class CustomRNN(tf.contrib.rnn.LSTMCell):
    def __init__(self, *args, **kwargs):
        kwargs['state_is_tuple'] = False # force the use of a concatenated state.
        returns = super(CustomRNN, self).__init__(*args, **kwargs) # create an lstm cell
        self._output_size = self._state_size # change the output size to the state size
        return returns
    def __call__(self, inputs, state):
        output, next_state = super(CustomRNN, self).__call__(inputs, state)
        return next_state, next_state # return two copies of the state, instead of the output and the state

X = np.random.randn(2, 10, 8)
X[1,6:] = 0
X_lengths = [10, 10]

cell = CustomRNN(num_units=64)

outputs, last_states = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float64,
    sequence_length=X_lengths,
    inputs=X)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())                                 
states, last_state = sess.run([outputs, last_states], feed_dict=None)

这使用连接状态,因为我不知道您是否可以存储任意数量的元组状态。 states 变量的形状为 (batch_size, max_time_size, state_size)。

【讨论】:

  • 您能详细说明一下这个 CustomRNN 代码是如何返回中间状态的吗?我正在尝试理解您的代码!
  • LSTM 状态是输出 (m) 和隐藏状态 (c) 的组合。此代码获取输出 (m) 并将其替换为连接状态 (c + m)。忽略批量大小,输出是 [(c1 + m1), (c2 + m2), ... ] 的列表,而不是 [m1, m2, ...]。
  • 那么,这用隐藏状态(c)替换了实际输出(m),对(return next_state, next_state而不是return m, new_state)?你在哪里连接输出和隐藏状态(m + c)?
  • new_state 对于 LSTM 来说是 (c + m),因此返回 new_state, new_state 会将输出 m 替换为 (c + m)。请参阅实现中的this line
【解决方案2】:

我会把你指向这个thread(我的亮点):

如果每个时间步都需要 c 和 h 状态,您可以编写 LSTMCell 的变体,将两个状态张量作为输出的一部分返回。如果只需要h状态,那就是每个时间步的输出

正如@jasekp 在其评论中所写,输出实际上是状态的h 部分。然后dynamic_rnn 方法将跨时间堆叠所有h 部分(参见this file_dynamic_rnn_loop 的字符串文档):

def _dynamic_rnn_loop(cell,
                      inputs,
                      initial_state,
                      parallel_iterations,
                      swap_memory,
                      sequence_length=None,
                      dtype=None):
  """Internal implementation of Dynamic RNN.
    [...]
    Returns:
    Tuple `(final_outputs, final_state)`.
    final_outputs:
      A `Tensor` of shape `[time, batch_size, cell.output_size]`.  If
      `cell.output_size` is a (possibly nested) tuple of ints or `TensorShape`
      objects, then this returns a (possibly nsted) tuple of Tensors matching
      the corresponding shapes.

【讨论】:

  • LSTMCell 只是一个单元格,如果我没记错的话,它会返回状态和输出。我认为tf.nn.dynamic_rnn 的展开部分只返回最后一步。那么,我需要修改它吗?奇怪的是,目前还没有更高级别的解决方案。
猜你喜欢
  • 2021-08-11
  • 1970-01-01
  • 1970-01-01
  • 2017-02-04
  • 1970-01-01
  • 2022-12-10
  • 1970-01-01
  • 2018-04-30
  • 1970-01-01
相关资源
最近更新 更多