【问题标题】:why set return_sequences=True and stateful=True for tf.keras.layers.LSTM?为什么要为 tf.keras.layers.LSTM 设置 return_sequences=True 和 stateful=True?
【发布时间】:2019-03-22 08:56:18
【问题描述】:

我正在学习tensorflow2.0,关注tutorial。在rnn的例子中,我找到了代码:

def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
  model = tf.keras.Sequential([
    tf.keras.layers.Embedding(vocab_size, embedding_dim, 
                              batch_input_shape=[batch_size, None]),
    tf.keras.layers.LSTM(rnn_units, 
                        return_sequences=True, 
                        stateful=True, 
                        recurrent_initializer='glorot_uniform'),
    tf.keras.layers.Dense(vocab_size)
  ])
  return model

我的问题是:为什么代码设置参数return_sequences=Truestateful=True?使用默认参数怎么样?

【问题讨论】:

    标签: tensorflow keras lstm recurrent-neural-network


    【解决方案1】:

    本教程中的示例是关于文本生成的。这是批量输入网络的输入:

    (64, 100, 65) # (batch_size, sequence_length, vocab_size)

    1. return_sequences=True

    由于目的是为每个时间步预测一个字符,即对于序列中的每个字符,需要预测下一个字符。

    因此,参数 return_sequences=True 设置为 true,以获得 (64, 100, 65) 的输出形状。如果此参数设置为 False,则仅返回最后一个输出,因此对于 64 个批次,输出将为 (64, 65),即对于每个 100 个字符的序列,仅返回最后一个预测字符。

    1. stateful=True

    从文档中, “如果为 True,则批次中索引 i 处每个样本的最后状态将用作下一批中索引 i 样本的初始状态。”

    在本教程的下图中,您可以看到设置有状态有助于 LSTM 通过提供先前预测的上下文来做出更好的预测。

    【讨论】:

    • 你能引用拍摄图像的教程吗?我的理解是stateful=True 跨批次共享上下文,而不是预测
    • 它来自问题中提到的同一教程:tensorflow.org/tutorials/text/…
    • @rtrtrt 每个预测都是在不同的批次中进行的,例如第一批将有前 4000 个样本的第一个字符,下一批将有接下来的 4000 个样本的第二个字符
    • (64, 100, 65) # (batch_size, sequence_length, vocab_size) 中不应该是number_of_features_per_time_step 而不是vocab_size?因为,当return_sequences=False 它只返回最后一个时间步的特征(对于批次中的每个示例),而不是 100 个时间步中的每一个 [并且维度变为 (64, 65)]。
    【解决方案2】:

    返回序列

    让我们看一下使用 LSTM 构建的典型模型架构。

    序列到序列模型:

    我们输入一系列输入 (x),一次一批,每个 LSTM 单元返回一个输出 (y_i)。因此,如果您的输入大小为batch_size x time_steps X input_size,那么 LSTM 输出将为batch_size X time_steps X output_size。这被称为序列到序列模型,因为输入序列被转换为输出序列。该模型的典型用途是标注器(POS 标注器、NER 标注器)。在 keras 中,这是通过设置 return_sequences=True 来实现的。

    序列分类 - 多对一架构

    在多对一架构中,我们使用最后一个 LSTM 单元的输出状态。这种架构通常用于分类问题,例如预测电影评论(表示为单词序列)是否为 +ve of -ve。在 keras 中,如果我们设置return_sequences=False,模型只返回最后一个 LSTM 单元的输出状态。

    有状态

    一个 LSTM 单元由许多门组成,如下图 this blog post 所示。前一个单元的状态/门用于计算当前单元的状态。在 keras 中,如果 stateful=False 则在每批后重置状态。如果stateful=True,上一批索引i 的状态将用作下一批索引i 的初始状态。所以状态信息会在批次之间以stateful=True 传播。检查此link 以通过示例解释有状态的有用性。

    【讨论】:

    • “对于索引 i”是什么意思?
    • @rtrtrt 表示 LSTM 单元在时间步解开 i
    【解决方案3】:

    让我们看看使用参数时的区别:

    tf.keras.backend.clear_session()
    tf.set_random_seed(42)
    X = np.array([[[1,2,3],[4,5,6],[7,8,9]],[[1,2,3],[4,5,6],[0,0,0]]], dtype=np.float32)
    model = tf.keras.Sequential([tf.keras.layers.LSTM(4, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform')])
    print(tf.keras.backend.get_value(model(X)).shape)
    # (2, 3, 4)
    print(tf.keras.backend.get_value(model(X)))
    # [[[-0.16141939  0.05600287  0.15932009  0.15656665]
    #  [-0.10788933  0.          0.23865232  0.13983202]
       [-0.          0.          0.23865232  0.0057992 ]]
    
    # [[-0.16141939  0.05600287  0.15932009  0.15656665]
    #  [-0.10788933  0.          0.23865232  0.13983202]
    #  [-0.07900514  0.07872108  0.06463861  0.29855606]]]
    

    因此,如果将 return_sequences 设置为 True,则模型会返回其预测的完整序列。

    tf.keras.backend.clear_session()
    tf.set_random_seed(42)
    model = tf.keras.Sequential([
    tf.keras.layers.LSTM(4, return_sequences=False, stateful=True, recurrent_initializer='glorot_uniform')])
    print(tf.keras.backend.get_value(model(X)).shape)
    # (2, 4)
    print(tf.keras.backend.get_value(model(X)))
    # [[-0.          0.          0.23865232  0.0057992 ]
    #  [-0.07900514  0.07872108  0.06463861  0.29855606]]
    

    因此,如文档所述,如果 return_sequences 设置为 False,则模型仅返回最后一个输出。

    至于stateful,这有点难以深入。但本质上,当有多个输入批次时,批次i 的最后一个单元状态将是批次i+1 的初始状态。但是,我认为使用默认设置会更好。 ​

    【讨论】:

      猜你喜欢
      • 2021-08-14
      • 1970-01-01
      • 2011-05-26
      • 2014-11-05
      • 2010-11-10
      • 2023-03-09
      • 2013-02-02
      • 2016-09-17
      相关资源
      最近更新 更多