【问题标题】:Masking zero inputs in LSTM in keras without using embedding在不使用嵌入的情况下在 keras 中屏蔽 LSTM 中的零输入
【发布时间】:2018-11-06 13:25:14
【问题描述】:

我正在 Keras 中训练 LSTM:

iclf = Sequential()
iclf.add(Bidirectional(LSTM(units=10, return_sequences=True, recurrent_dropout=0.3), input_shape=(None,2048)))
iclf.add(TimeDistributed(Dense(1, activation='sigmoid')))

每个单元格的输入是一个 2048 向量,该向量已知且无需学习(如果愿意,它们是输入句子中单词的 ELMo 嵌入)。因此,这里我没有嵌入层。

由于输入序列具有可变长度,因此使用pad_sequences 填充它们:

X = pad_sequences(sequences=X, padding='post', truncating='post', value=0.0, dtype='float32')

现在,我想告诉 LSTM 忽略这些填充元素。官方的做法是用mask_zero=True的Embedding层。但是,这里没有嵌入层。如何通知 LSTM 屏蔽零个元素?

【问题讨论】:

  • 你可以使用Masking层。您可能会发现this answer 的第二部分在这方面很有用。
  • 非常感谢@today。

标签: keras lstm embedding


【解决方案1】:

正如@Today 在评论中所建议的,您可以使用Masking 层。在这里,我添加了一个玩具问题。

# lstm autoencoder recreate sequence
from numpy import array
from keras.models import Sequential
from keras.layers import LSTM, Masking
from keras.layers import Dense
from keras.layers import RepeatVector
from keras.layers import TimeDistributed
from keras.utils import plot_model
from keras.preprocessing.sequence import pad_sequences


# define input sequence
sequence = array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], 
                  [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
                  [0.3, 0.4, 0.5, 0.6]])
# make sure to use dtype='float32' in padding otherwise with floating points
sequence = pad_sequences(sequence, padding='post', dtype='float32')


# reshape input into [samples, timesteps, features]
n_obs = len(sequence)
n_in = 9
sequence = sequence.reshape((n_obs, n_in, 1))

# define model
model = Sequential()
model.add(Masking(mask_value=0, input_shape=(n_in, 1)))
model.add(LSTM(100, activation='relu', input_shape=(n_in,1) ))
model.add(RepeatVector(n_in))
model.add(LSTM(100, activation='relu', return_sequences=True))
model.add(TimeDistributed(Dense(1)))
model.compile(optimizer='adam', loss='mse')
# fit model
model.fit(sequence, sequence, epochs=300, verbose=0)
plot_model(model, show_shapes=True, to_file='reconstruct_lstm_autoencoder.png')
# demonstrate recreation
yhat = model.predict(sequence, verbose=0)
print(yhat[0,:,0])

【讨论】:

    猜你喜欢
    • 2021-09-18
    • 2020-06-28
    • 2021-03-03
    • 1970-01-01
    • 1970-01-01
    • 2018-03-21
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多