【问题标题】:Conceptual understanding of tf.nn.dynamic_rnn() "outputs" vs. "state"tf.nn.dynamic_rnn()“输出”与“状态”的概念理解
【发布时间】:2019-07-30 20:11:46
【问题描述】:

上下文

我正在阅读Hands on ML 的第二部分,并正在寻找关于在 RNN 的损失计算中何时使用“输出”以及何时使用“状态”的一些明确说明。

在书中(有书的人第 396 页),作者说,“请注意,全连接层连接到 states 张量,其中仅包含 RNN 的最终状态,”指到展开超过 28 个步骤的序列分类器。由于states 变量将具有len(states) == <number_of_hidden_layers>,因此在构建深度 RNN 时,我一直使用 states[-1] 仅连接到最后一层的最终状态。例如:

# hidden_layer_architecture = list of ints defining n_neurons in each layer
# example: hidden_layer_architecture = [100 for _ in range(5)]
layers = []
for layer_id, n_neurons in enumerate(hidden_layer_architecture):

    hidden_layer = tf.contrib.rnn.BasicRNNCell(n_neurons, 
                                               activation=tf.nn.tanh,                                                                                                                                                                     
                                               name=f'hidden_layer_{layer_id}')

    layers.append(hidden_layer)

recurrent_hidden_layers = tf.contrib.rnn.MultiRNNCell(layers)
outputs, states = tf.nn.dynamic_rnn(recurrent_hidden_layers,
                                    X_, dtype=tf.float32)

logits = tf.layers.dense(states[-1], n_outputs, name='outputs')

考虑到作者之前的陈述,这可以正常工作。但是,我不明白什么时候会使用outputs 变量(tf.nn.dynamic_rnn() 的第一个输出)

我看过this question,它在回答细节方面做得很好,并提到,“如果你只对单元格的最后一个输出感兴趣,你可以切分时间维度来选择最后一个元素(例如outputs[:, -1, :])。”我推断这意味着类似于states[-1] == outputs[:, -1, :],在测试时它是错误的。为什么不是这样呢?如果输出是每个时间步的单元格的输出,为什么不是这种情况?一般...

问题

什么时候在损失函数中使用tf.nn.dynamic_rnn() 中的outputs 变量,什么时候使用states 变量?这如何改变网络的抽象架构?

任何澄清将不胜感激。

【问题讨论】:

  • 仅供参考:在 cmets 中“标记”某人无济于事。

标签: python tensorflow machine-learning neural-network


【解决方案1】:

这基本上把它分解了:

outputs:RNN 顶层的完整输出序列。这意味着,如果您使用MultiRNNCell,这将只是 top 单元格;这里没有来自下层细胞的任何东西。
一般来说,使用自定义 RNNCell 实现,这几乎可以是任何东西,但是几乎所有标准单元格都返回 states 的序列,但是您也可以自己编写一个自定义单元格在将其作为输出返回之前对状态序列进行一些处理(例如线性变换)。

state(请注意,这是文档中的名称,不是states)是最后时间步的完整状态。一个重要的区别是,在MultiRNNCell 的情况下,这将包含序列中所有 单元格的最终状态,而不仅仅是顶部!此外,此输出的精确格式/类型在很大程度上取决于所使用的 RNNCell(例如,它可能是张量,或张量的元组......)。

因此,如果您只关心MultiRNNCell 中最后一步的最顶层状态,那么您确实有两个应该相同的选项,归结为个人偏好/“清晰度”:

  • outputs[:, -1, :](假设批处理主要格式)仅从顶级状态序列中提取最后一个时间步。
  • state[-1] 仅从所有层的最终状态元组中提取顶级状态。

在其他情况下您可能没有此选择:

  • 如果确实需要全序列输出,需要使用outputs
  • 如果您需要MultiRNNCell 中来自较低层的最终状态,则需要使用state

至于相等性检查失败的原因:如果您实际使用==,我相信这会检查明显不同的张量对象的相等性。您可以改为尝试检查两个对象的 以了解一些简单的玩具场景(微小的状态大小/序列长度)——它们应该是相同的。

【讨论】:

  • 完美答案,谢谢。相等 (==) 测试正在测试对象的相等性,经过进一步检查,张量中的值是相同的。将output[:, -1, :] 代替state[-1] 代入损失函数会产生类似的结果。
猜你喜欢
  • 2018-01-27
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2020-06-25
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多