【问题标题】:Truncated Backpropagation in keras with one sequence per batchkeras 中的截断反向传播,每批一个序列
【发布时间】:2019-04-11 18:38:08
【问题描述】:

如果我理解正确,要在 keras 中执行 TBPTT,我们必须将序列拆分为 k 个时间步长的较小部分。根据 keras 的文档,要在序列的所有部分重用 LSTM 的状态,我们必须使用 stateful 参数:

您可以将 RNN 层设置为“有状态”,这意味着为一个批次中的样本计算的状态将被重新用作下一批中的样本的初始状态。这假设不同连续批次中的样本之间存在一对一的映射关系。

所以,如果我理解正确,第一批的第一个样本是第一个序列的第 1 部分,第二批的第一个样本是第一个序列的第二部分,等等。我有 125973 个长度为 1000 的序列我分成 40 个 k=25 时间步长的序列。所以我的模型应该训练 40 个批次,包含 25 个时间步长的 125973 个序列。我的问题是我的GPU(quadro K2200,我很穷)的内存,125973的批量大小似乎太多了。我想知道是否可以将 LSTM 的状态保持在同一个批次中并在批次之间进行重置,所以我的批次大小应该是 40 和 125973 批次。

这是我的模型:

model = Sequential()
model.add(Embedding(len(char_to_num), 200, mask_zero=True, batch_input_shape=(batch_size, k)))
model.add(Dropout(0.5))
model.add(LSTM(512, activation='relu', return_sequences=True, stateful=True))
model.add(Dropout(0.5))
model.add(TimeDistributed(Dense(len(char_to_num), activation='softmax')))

model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')
model.summary()

2021 年编辑
今年已经给出了最近的答案,但这是一个古老的问题。与此同时,图书馆、DL 和 NLP 的状态发生了很大变化,我已经从 LSTM 转向了 Transformers。我已经很多年没有使用过 LSTM,也没有计划也没有时间测试发布的答案。

【问题讨论】:

  • 你得到答案了吗?

标签: python keras deep-learning backpropagation


【解决方案1】:

到目前为止,您的批量大小是灵活的,它必须除以 P = 125973。如果没有这样的数字(例如,因为 P 是质数),那么只需添加每个填充有千个零的虚拟序列。如果添加了虚拟序列,请确保在训练期间通过将适当的“sample_weights”nd-array 添加到 model.fit() 来忽略它们(其中真实序列用“1”屏蔽,虚拟序列用“0”屏蔽),并且调用 model.compile(.., sample_weight_mode='temporal')。

然后,要在批次之间重置状态,请使用 keras 回调:

# N must be divisible by batch_size
N = 40*126000  # number of time series snippets (sequences + dummies)
batch_size = 50  # processing 50 sequences at a time

class StateResetter(tf.keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs={}):
        # reset states if we processed a set of sequences
        if (batch+1) % 40 == 0:
            self.model.get_layer('my_lstm_layer').reset_states()

# input_data.shape = (N, 25, num_features)
model.fit(input_data, labels, batch_size=batch_size, 
          callbacks=[StateResetter], sample_weight=sample_weight)

我想你应该能够弄清楚如何相应地调整 input_data。

【讨论】:

    【解决方案2】:

    我想知道是否可以将 LSTM 的状态保持在同一个批次中并在批次之间重置它...

    这是为了更好地训练 LSTM 模型而采取的方法。这是因为批次中的样本将在时间上彼此相邻,并且当以有状态的方式对每个批次进行训练时,网络可以得到很好的训练。具有较小批量大小的内存节省是一个理想的副作用。

    如@Kirgsn 所示,可以在每批之后重置状态。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2018-05-05
      • 2020-11-04
      • 2019-05-23
      • 1970-01-01
      • 2017-09-02
      • 2021-03-19
      • 1970-01-01
      相关资源
      最近更新 更多