【Question title】: RNN layer with unequal input and output lengths in TF/Keras
【Posted】: 2020-04-20 11:53:40
【Question】:

Is it possible to get a variable output length from an RNN, i.e. input_seq_length != output_seq_length?

Here is an example showing LSTM output shapes. test_rnn_output_v1 uses the default settings and returns only the output of the last step; test_rnn_output_v2 returns the outputs of all steps. I need something like test_rnn_output_v2, but with output shape (None, variable_seq_length, rnn_dim), or at least (None, max_output_seq_length, rnn_dim).

from keras.layers import Input
from keras.layers import LSTM
from keras.models import Model


def test_rnn_output_v1():
    max_seq_length = 10
    n_features = 4
    rnn_dim = 64

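    # Named `inputs` to avoid shadowing the built-in input().
    inputs = Input(shape=(max_seq_length, n_features))
    # Default return_sequences=False: only the last time step's output is returned.
    out = LSTM(rnn_dim)(inputs)

    model = Model(inputs=[inputs], outputs=out)

    # summary() prints the table itself and returns None, so print() is not needed.
    model.summary()

    # Input shape:  (None, max_seq_length, n_features)
    # Output shape: (None, rnn_dim)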


def test_rnn_output_v2():
    max_seq_length = 10
    n_features = 4
    rnn_dim = 64

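    # Named `inputs` to avoid shadowing the built-in input().
    inputs = Input(shape=(max_seq_length, n_features))
    # return_sequences=True: one output per input time step.
    out = LSTM(rnn_dim, return_sequences=True)(inputs)

    model = Model(inputs=[inputs], outputs=out)

    # summary() prints the table itself and returns None, so print() is not needed.
    model.summary()

    # Input shape:  (None, max_seq_length, n_features)
    # Output shape: (None, max_seq_length, rnn_dim)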


test_rnn_output_v1()
test_rnn_output_v2()

【Comments】:

Tags: python machine-learning keras lstm recurrent-neural-network


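【Answer 1】:

By definition, a single RNN layer cannot have different input and output lengths: it returns either one output per input time step (return_sequences=True) or only the last output. There is, however, a trick that gives an output length that differs from the input length but is fixed: use two RNN layers with a RepeatVector layer between them. Here is a minimal example model that accepts input sequences of variable length and produces an output sequence of a fixed, arbitrary length: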

import tensorflow as tf

max_output_length = 35

inp = tf.keras.layers.Input(shape=(None, 10))
x = tf.keras.layers.LSTM(20)(inp)
x = tf.keras.layers.RepeatVector(max_output_length)(x)
out = tf.keras.layers.LSTM(30, return_sequences=True)(x)

model = tf.keras.Model(inp, out)
model.summary()

Here is the model summary:

Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, None, 10)]        0         
_________________________________________________________________
lstm (LSTM)                  (None, 20)                2480      
_________________________________________________________________
repeat_vector (RepeatVector) (None, 35, 20)            0         
_________________________________________________________________
lstm_1 (LSTM)                (None, 35, 30)            6120      
=================================================================
Total params: 8,600
Trainable params: 8,600
Non-trainable params: 0
_________________________________________________________________
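As a quick sanity check on the Param # column, the counts can be reproduced by hand. The sketch below assumes the standard Keras LSTM layout (four gates, each with an input kernel, a recurrent kernel, and a bias); the helper name lstm_param_count is made up for illustration and is not part of Keras.

```python
def lstm_param_count(input_dim, units):
    """Parameter count of a standard LSTM layer (hypothetical helper).

    Each of the 4 gates has an input kernel (input_dim x units),
    a recurrent kernel (units x units), and a bias (units).
    """
    return 4 * (units * (input_dim + units) + units)


# First LSTM: 10 input features -> 20 units
encoder_params = lstm_param_count(input_dim=10, units=20)
# Second LSTM: receives the repeated 20-dim vector -> 30 units
decoder_params = lstm_param_count(input_dim=20, units=30)

print(encoder_params)                   # 2480
print(decoder_params)                   # 6120
print(encoder_params + decoder_params)  # 8600
```

RepeatVector has no weights, which is why it contributes 0 to the total.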

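This structure can be used for sequence-to-sequence models in which the length of the input sequence is not necessarily the same as the length of the output sequence.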
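To make the behaviour concrete, here is a minimal sketch that feeds batches of different lengths through the same architecture and records the output shapes. The wrapper function build_model, the batch size of 2, the sequence lengths 5, 12 and 50, and the random input data are all assumptions chosen for illustration; only the layer stack comes from the answer.

```python
import numpy as np
import tensorflow as tf

max_output_length = 35  # fixed output length, as in the model above


def build_model(n_features=10, output_length=max_output_length):
    # Hypothetical wrapper around the layer stack from the answer.
    inp = tf.keras.layers.Input(shape=(None, n_features))  # variable input length
    x = tf.keras.layers.LSTM(20)(inp)                       # (batch, 20)
    x = tf.keras.layers.RepeatVector(output_length)(x)      # (batch, 35, 20)
    out = tf.keras.layers.LSTM(30, return_sequences=True)(x)
    return tf.keras.Model(inp, out)


model = build_model()

output_shapes = {}
for seq_len in (5, 12, 50):  # arbitrary input lengths
    batch = np.random.rand(2, seq_len, 10).astype("float32")
    output_shapes[seq_len] = tuple(model(batch).shape)
    print(seq_len, "->", output_shapes[seq_len])
```

Every input length maps to the same (2, 35, 30) output, because the first LSTM collapses the input to a single vector before RepeatVector expands it again. Note that sequences within one batch must still share a length.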

【Comments】:
