【问题标题】:How to extract cell state from a LSTM at each timestep in Keras?如何在 Keras 的每个时间步从 LSTM 中提取单元状态?
【发布时间】:2018-08-27 03:09:24
【问题描述】:

Keras 有没有办法在给定输入的每个时间步长检索 LSTM 层的单元状态(即 c 向量)?

似乎return_state 参数返回计算完成后的最后一个单元状态,但我还需要中间的。另外,我不想将这些单元格状态传递给下一层,我只想能够访问它们。

最好使用 TensorFlow 作为后端。

谢谢

【问题讨论】:

  • 您找到解决方案了吗?我现在正在研究完全相同的问题。
  • 我没有找到一种简单直观的方法来做到这一点。但是,如果您创建一个将 LSTM 层作为模型中唯一层的模型(只是复制权重)并将 return_state 设置为 true,您可以获得序列生成的最后一个单元状态。因此,您可以只处理序列直到给定时间步,以获得该时间步产生的单元状态。
  • 例如,如果您的序列最初有 100 个时间步,但您想知道第 40 个时间步之后的细胞状态,您只需删除最后 60 个时间步并在层中运行新序列。这是一个非常蹩脚的解决方案,但我认为唯一可行的解​​决方案。不过我没有尝试,因为我在我正在处理的项目中改变了我的方法。

标签: tensorflow keras deep-learning lstm rnn


【解决方案1】:

我知道现在已经很晚了,我希望这能有所帮助。

从技术上讲,您可以通过在调用方法中修改 LSTM 单元来实现您的要求。当您提供return_sequences=True 时,我对其进行修改并使其返回 4 维而不是 3。

代码

from keras.layers.recurrent import _generate_dropout_mask
class Mod_LSTMCELL(LSTMCell):
    def call(self, inputs, states, training=None):
        if 0 < self.dropout < 1 and self._dropout_mask is None:
            self._dropout_mask = _generate_dropout_mask(
                K.ones_like(inputs),
                self.dropout,
                training=training,
                count=4)
        if (0 < self.recurrent_dropout < 1 and
                self._recurrent_dropout_mask is None):
            self._recurrent_dropout_mask = _generate_dropout_mask(
                K.ones_like(states[0]),
                self.recurrent_dropout,
                training=training,
                count=4)

            # dropout matrices for input units
        dp_mask = self._dropout_mask
        # dropout matrices for recurrent units
        rec_dp_mask = self._recurrent_dropout_mask

        h_tm1 = states[0]  # previous memory state
        c_tm1 = states[1]  # previous carry state

        if self.implementation == 1:
            if 0 < self.dropout < 1.:
                inputs_i = inputs * dp_mask[0]
                inputs_f = inputs * dp_mask[1]
                inputs_c = inputs * dp_mask[2]
                inputs_o = inputs * dp_mask[3]
            else:
                inputs_i = inputs
                inputs_f = inputs
                inputs_c = inputs
                inputs_o = inputs
            x_i = K.dot(inputs_i, self.kernel_i)
            x_f = K.dot(inputs_f, self.kernel_f)
            x_c = K.dot(inputs_c, self.kernel_c)
            x_o = K.dot(inputs_o, self.kernel_o)
            if self.use_bias:
                x_i = K.bias_add(x_i, self.bias_i)
                x_f = K.bias_add(x_f, self.bias_f)
                x_c = K.bias_add(x_c, self.bias_c)
                x_o = K.bias_add(x_o, self.bias_o)

            if 0 < self.recurrent_dropout < 1.:
                h_tm1_i = h_tm1 * rec_dp_mask[0]
                h_tm1_f = h_tm1 * rec_dp_mask[1]
                h_tm1_c = h_tm1 * rec_dp_mask[2]
                h_tm1_o = h_tm1 * rec_dp_mask[3]
            else:
                h_tm1_i = h_tm1
                h_tm1_f = h_tm1
                h_tm1_c = h_tm1
                h_tm1_o = h_tm1
            i = self.recurrent_activation(x_i + K.dot(h_tm1_i,
                                                      self.recurrent_kernel_i))
            f = self.recurrent_activation(x_f + K.dot(h_tm1_f,
                                                      self.recurrent_kernel_f))
            c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1_c,
                                                            self.recurrent_kernel_c))
            o = self.recurrent_activation(x_o + K.dot(h_tm1_o,
                                                      self.recurrent_kernel_o))
        else:
            if 0. < self.dropout < 1.:
                inputs *= dp_mask[0]
            z = K.dot(inputs, self.kernel)
            if 0. < self.recurrent_dropout < 1.:
                h_tm1 *= rec_dp_mask[0]
            z += K.dot(h_tm1, self.recurrent_kernel)
            if self.use_bias:
                z = K.bias_add(z, self.bias)

            z0 = z[:, :self.units]
            z1 = z[:, self.units: 2 * self.units]
            z2 = z[:, 2 * self.units: 3 * self.units]
            z3 = z[:, 3 * self.units:]

            i = self.recurrent_activation(z0)
            f = self.recurrent_activation(z1)
            c = f * c_tm1 + i * self.activation(z2)
            o = self.recurrent_activation(z3)

        h = o * self.activation(c)
        if 0 < self.dropout + self.recurrent_dropout:
            if training is None:
                h._uses_learning_phase = True
        return tf.expand_dims(tf.concat([h,c],axis=0),0), [h, c]

示例代码

# create a cell
test = Mod_LSTMCELL(100)

# Input timesteps=10, features=7
in1 = Input(shape=(10,7))
out1 = RNN(test, return_sequences=True)(in1)

M = Model(inputs=[in1],outputs=[out1])
M.compile(keras.optimizers.Adam(),loss='mse')

ans = M.predict(np.arange(7*10,dtype=np.float32).reshape(1, 10, 7))

print(ans.shape)
# state_h
print(ans[0,0,0,:])
# state_c
print(ans[0,0,1,:])

【讨论】:

  • 此解决方案极其复杂且不必要。请参阅我的答案以了解执行此操作的正确方法,该方法不需要子类化 LSTMCell
  • 不需要? ,我相信您的答案不适用于与 Keras API 集成,因为我的技术只是创建一个自定义单元格,以便您可以将其与现有 API 一起使用。例如,您可以将 model.add() 用于顺序模型。此外,该方法还创建了一个包含所有值 h,c 的张量,您可以将其用于构建巨大的模型,而无需每次使用时都执行 concat 层,从而使代码更清晰,更易于阅读。这并不复杂,你只需复制我的类并像使用普通 LSTM 单元和 RNN 一样使用它。
  • 您编写了一个自定义类来返回 [h,c] 状态,而 LSTMCell 实际上已经这样做了。为什么要添加不必要的复杂性来实现 LSTMCell 已经具有的功能。此外,您的代码甚至不返回激活。好的 - 您的代码可以与顺序 API 一起使用。我的代码是执行此操作的“正确”方法,应与功能性 api 或子类 tf.keras.Model 一起使用。
  • 相信你没有看过任何keras源码,我的实现基于keras代码是正确的。大多数代码保持不变,而我只是修改了一些代码以返回 C 状态,因此您在增加复杂性方面显然是错误的,因为我的代码只是实际 LSTM 所做的后端。接下来,如果您的模型无法在 keras API 上实现,那么使用像 Keras 这样的高级 API 有什么意义,我们可以使用纯低级 tenserflow 来实现所有这些。
  • 接下来,谈到您的实现复杂性,它不使用 keras 层类,当您尝试可视化计算图和调试时,这会使您的代码更难看,而 Keras 具有帮助可视化图形的层类,因此,我的实现可能看起来很混乱,但这不仅仅是可视化,而且可以与其他 keras API 无缝集成,而您的代码只是一个低级实现,没有可扩展性和低模态,这会导致代码的可重用性差,我相信让你的实现变得非常复杂,而不是我的。
【解决方案2】:

我一直在寻找解决此问题的方法,在阅读了在 tf.keras (https://www.tensorflow.org/api_docs/python/tf/keras/layers/AbstractRNNCell) 中创建自己的自定义 RNN 单元的指南后,我相信以下是最简洁易读的方法对于 TensorFlow 2:

import tensorflow as tf
from tensorflow.keras.layers import LSTMCell

class LSTMCellReturnCellState(LSTMCell):

    def call(self, inputs, states, training=None):
        real_inputs = inputs[:,:self.units] # decouple [h, c]
        outputs, [h,c] = super().call(real_inputs, states, training=training)
        return tf.concat([h, c], axis=1), [h,c]



num_units = 512
test_input = tf.random.uniform([5,100,num_units])

rnn = tf.keras.layers.RNN(LSTMCellReturnCellState(num_units),
                          return_sequences=True, return_state=True)

whole_seq_output, final_memory_state, final_carry_state = rnn(test_input)

print(whole_seq_output.shape)
>>> (5,100,1024)

# Hidden state sequence
h_seq = whole_seq_output[:,:,:num_units] # (5,100,512)

# Cell state sequence
c_seq = whole_seq_output[:,:,num_units:] # (5,100,512)

如上述解决方案中所述,您可以看到这样做的好处是可以轻松地将其包装到tf.keras.layers.RNN 中,作为普通LSTMCell 的一个插件。

这是一个Colab Notebook,代码按预期运行tensorflow==2.6.0

【讨论】:

    【解决方案3】:

    首先,tf.keras.layers.LSTM 无法做到这一点。您必须改用 LSTMCell 或子类 LSTM。其次,不需要子类 LSTMCell 来获得细胞状态的序列。每次调用 LSTMCell 时,它都会返回隐藏状态 (h) 和单元状态 (c) 的列表。 对于那些不熟悉 LSTMCell 的人,它接受当前的 [h, c] 张量和当前时间步的输入(它不能接受时间序列)并返回激活和更新的 [h,c]。 这是一个示例,展示了如何使用 LSTMCell 处理一系列时间步长并返回累积的单元状态。

    # example inputs
    inputs = tf.convert_to_tensor(np.random.rand(3, 4), dtype='float32')  # 3 timesteps, 4 features
    h_c = [tf.zeros((1,2)),  tf.zeros((1,2))]  # must initialize hidden/cell state for lstm cell
    h_c = tf.convert_to_tensor(h_c, dtype='float32')
    lstm = tf.keras.layers.LSTMCell(2)
    
    # example of how you accumulate cell state over repeated calls to LSTMCell
    inputs = tf.unstack(inputs, axis=0)
    c_states = []
    for cur_inputs in inputs:
        out, h_c = lstm(tf.expand_dims(cur_inputs, axis=0), h_c)
        h, c = h_c
        c_states.append(c)
    

    【讨论】:

      【解决方案4】:

      您可以通过在初始化程序中设置return_sequences = True 来访问任何RNN 的状态。你可以找到更多关于这个参数here的信息。

      【讨论】:

      • 据我了解,“return_sequence=True”返回所有隐藏状态(即“h”向量)。我想要访问的是 LSTM 的单元状态('c' 向量)。
      猜你喜欢
      • 2017-07-26
      • 1970-01-01
      • 1970-01-01
      • 2017-10-08
      • 1970-01-01
      • 1970-01-01
      • 2020-03-28
      • 2018-03-30
      • 2017-08-31
      相关资源
      最近更新 更多