【Question title】: Setting up an initial state in LSTM in Tensorflow 1.9
【Posted】: 2018-07-19 07:26:34
【Question】:

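I am trying to build a simple LSTM network with 2 stacked layers, using MultiRNNCell. I followed tutorials and other Stack Overflow threads, but I still cannot get my network to run. Below is the initial-state declaration that I found on Stack Overflow.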

cell_count = 10 # timesteps
num_hidden = 4 # hidden layer num of features
num_classes = 1 
num_layers = 2
state_size = 4

init_c = tf.Variable(tf.zeros([batch_size, cell_count]), trainable=False)
init_h = tf.Variable(tf.zeros([batch_size, cell_count]), trainable=False)
initial_state = rnn.LSTMStateTuple(init_c, init_h) #[num_layers, 2, batch_size, state_size])

Below is what my model looks like:

def generate_model_graph(self, data):

    L1 = self.generate_layer(self.cell_count)
    L2 = self.generate_layer(self.cell_count)

    #outputs from L1
    L1_outs, _ = L1(data, self.initial_state)

    #reverse output array
    L2_inputs = L1_outs[::-1]

    L2_outs, _ = L2(L2_inputs, self.initial_state)
    predicted_vals = tf.add(tf.matmul(self.weights["out"], L2_outs), self.biases["out"])
    L2_out = tf.nn.sigmoid(predicted_vals)
    return L2_out



def generate_layer(self, size):
    cells = [rnn.BasicLSTMCell(self.num_hidden) for _ in range(size)]
    return rnn.MultiRNNCell(cells)

And this is how I run the session:

def train_model(self, generator):
    tr, cost = self.define_model()

    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        for _ in range(self.n_epochs):
            batch_x, batch_y = self._prepare_data(generator)
            init_state = tf.zeros((self.cell_count, self.num_hidden))
            t, c = sess.run([tr, cost], feed_dict={self.X: batch_x, self.Y:batch_y, self.initial_state:init_state})
            print(c)

Unfortunately, I still get the error message 'Variable' object is not iterable:

  File "detector_lstm_v2.py", line 104, in <module>
    c.train_model(data_gen)
  File "detector_lstm_v2.py", line 38, in train_model
    tr, cost = self.define_model()
  File "detector_lstm_v2.py", line 51, in define_model
    predicted_vals = self.generate_model_graph(self.X)
  File "detector_lstm_v2.py", line 65, in generate_model_graph
    L1_outs, _ = L1(data, self.initial_state)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 232, in __call__
    return super(RNNCell, self).__call__(inputs, state)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/layers/base.py", line 329, in __call__
    outputs = super(Layer, self).__call__(inputs, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 703, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 1325, in call
    cur_inp, new_state = cell(cur_inp, cur_state)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 339, in __call__
    *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/layers/base.py", line 329, in __call__
    outputs = super(Layer, self).__call__(inputs, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py", line 703, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 633, in call
    c, h = state
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py", line 491, in __iter__
    raise TypeError("'Variable' object is not iterable.")
TypeError: 'Variable' object is not iterable.

Does anyone know how to solve this?

【Comments】:

    Tags: python python-3.x tensorflow lstm


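    【Answer 1】:

    You are creating a multi-layer RNN cell, but you are passing it a single state. A MultiRNNCell expects one state per layer.

    Use this to create your state: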

    initial_state = L1.zero_state(batch_size, tf.float32)
    

    If you need variables, you can also use it to initialize them.
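To illustrate why a single state tuple fails, here is a minimal sketch in plain Python. It does not use TensorFlow: `FakeTensor`, `zero_state` and `run_stacked_step` are hypothetical stand-ins that only mimic what the traceback suggests, namely that the stacked cell hands `state[i]` to layer `i`, which then runs `c, h = state`.

```python
from collections import namedtuple

# Stand-in for rnn.LSTMStateTuple (same field names: c, h).
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


class FakeTensor:
    """Hypothetical placeholder for a tf.Variable: has a shape, is not iterable."""

    def __init__(self, shape):
        self.shape = tuple(shape)


def zero_state(num_layers, batch_size, num_hidden):
    """Build one (c, h) pair per layer, each of shape [batch_size, num_hidden]."""
    return tuple(
        LSTMStateTuple(FakeTensor([batch_size, num_hidden]),
                       FakeTensor([batch_size, num_hidden]))
        for _ in range(num_layers)
    )


def run_stacked_step(state, num_layers):
    """Mimic the per-layer loop: take state[i] and unpack it into c and h."""
    shapes = []
    for i in range(num_layers):
        c, h = state[i]  # raises TypeError if state[i] is a bare tensor
        shapes.append((c.shape, h.shape))
    return shapes


# One state tuple per layer works.
per_layer_state = zero_state(num_layers=2, batch_size=3, num_hidden=4)
print(run_stacked_step(per_layer_state, num_layers=2))

# A single (c, h) tuple fails: state[0] is the bare c tensor.
single_state = LSTMStateTuple(FakeTensor([3, 4]), FakeTensor([3, 4]))
try:
    run_stacked_step(single_state, num_layers=2)
except TypeError as err:
    print("TypeError:", err)
```

With a single tuple, indexing picks out `init_c` itself, so the unpacking is attempted on a variable rather than on a `(c, h)` pair, which matches the error in the question.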

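    There are some "naming" issues in your code that make me think you have misunderstood a few things.

    There are several distinct parameters:

    1. The hidden size of a layer: this is the units argument (num_units) of the RNNCell constructor. Every state of your cells needs to have shape [batch_size, hidden_size] (not cell_count).
    2. cell_count in your code does not determine the length of the sequence; it determines the "depth" of the network, i.e. how many cells are stacked.
    3. The length of the sequence is determined automatically by the input sequence you pass to the model (which needs to be a list of tensors).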
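The three parameters above can be kept apart with a small shape calculation. This is a plain-Python sketch, not TensorFlow code: `expected_shapes` is a hypothetical helper, `input_dim` and `seq_len` are assumed names, and the output shapes assume the model is unrolled over a list of per-timestep inputs as described in point 3.

```python
def expected_shapes(batch_size, seq_len, input_dim, num_hidden, num_layers):
    """Return the shapes a stacked LSTM would expect and produce (illustrative only)."""
    state_shape = (batch_size, num_hidden)
    return {
        # One input tensor per timestep; the list length is the sequence length.
        "inputs": [(batch_size, input_dim)] * seq_len,
        # One (c, h) pair per layer; the number of pairs is the network depth.
        "state": [{"c": state_shape, "h": state_shape}] * num_layers,
        # One output per timestep, taken from the top layer.
        "outputs": [(batch_size, num_hidden)] * seq_len,
    }


shapes = expected_shapes(batch_size=3, seq_len=10, input_dim=1,
                         num_hidden=4, num_layers=2)
print(len(shapes["inputs"]))   # sequence length
print(len(shapes["state"]))    # depth of the stack
print(shapes["state"][0]["c"])  # shape of each state tensor
```

Note that the sequence length never appears in the state shape, which is why the state tensors in the question should use num_hidden rather than cell_count.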

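    I suggest you take a look at the TF tutorial on recurrent neural networks here, or at this answer here, to understand what an RNNCell is w.r.t. the RNN literature (it is a layer rather than a single cell).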

    【Comments】:
