【发布时间】:2020-02-19 04:20:18
【问题描述】:
我有一个非常可预测的序列。下面,你可以看到它的一部分:
deque([4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4])
基本上,它是一个可变但仍可预测的 4 数量,然后是 8,然后每三个 8 有 28 个。
我想为在线预测构建一个非常简单的 LSTM 模型:每次有一个新数字到达时,它就会附加在双端队列的右侧。因此,LSTM 在由 deque 的 [0:seq_length] 元素组成的旧序列上进行训练,训练目标是 [seq_length] 元素。然后,对 [1:seq_length+1] 元素执行窗口移动和预测。最后,deque 最左边的元素被丢弃。我的直觉告诉我,这应该让网络记住序列。
但是,我的网络一直只回答 4 个。经过(很长)一段时间后,令人惊讶的是,它开始只回答 8 个,几乎所有时间都没有。然后,(很长)一段时间后,它又回到了只回答 4 的问题。
我的模型结构如图所示。当然,我已经尝试了 seq_length 和 lstm_cells 的不同值,但没有一个能成功。这些来自最新的运行:
seq_length = 64 #Length of the sequence to be inserted into the LSTM
vocab_size = 4 #Size of the final dense layer of the model
lstm_cells = 16 #Size of the LSTM layer
model = Sequential()
model.add(LSTM(lstm_cells, input_shape=(seq_length, 1)))
model.add(Dense(vocab_size))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['categorical_accuracy'])
以下是如何在模型上准备、训练和预测数据。变量 sequence 是本文开头显示的双端队列。我维护了一个列表 vocab = [4,8,28],它是建立在执行时间上的,因为看到新的数字,所以 vocab[i] 翻译类 i 插入其对应的序列号。然后我创建一个字典 legend 来做相反的事情。这或多或少是正在进行的在线循环:
while True:
# Receives new number and puts it into the deque:
sequence.append(generateNextNumber())
# At this point, please note that the length of the deque is seq_length + 1.
# Dictionary to convert numbers to classes:
legend = dict([(v, k) for k, v in enumerate(vocab)])
# Converts the deque into a list:
seq_list = list(sequence)
# Each iteration is comprised of 1 training and 1 prediction. These are the training sequence and target:
train_seq = [ [legend[i]] for i in seq_list[:seq_length] ]
train_target = legend[ seq_list[seq_length] ]
# And the prediction sequence just shifts the window by 1:
pred_seq = [ [legend[i]] for i in seq_list[1:] ]
# Batches data into a batch of size 1:
x = np.zeros((1, seq_length, 1))
y = np.zeros((1, vocab_size))
x[0,:] = train_seq
y[0,:] = to_categorical( train_target, num_classes=vocab_size )
# Online training:
model.fit(x=x, y=y, batch_size=1, epochs=1, verbose=0)
# Now that one training step is done, make a prediction:
x_pred = np.zeros((1, seq_length, 1))
x[0,:] = pred_seq
predicted_onehot = model.predict(x_pred)
# Avoids "index out of range" erros when the LSTM vocab is still being built:
predicted_index = min(np.argmax(predicted_onehot), len(vocab)-1)
predicted_number = vocab[ predicted_index ]
# Reverts deque length to seq_length:
sequence.popleft()
最后,这是一个示例输出:
HIT! Current hit rate: 34.753665869071725 (predicted: 4, sequence was: deque([4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4]))
HIT! Current hit rate: 34.75566735175926 (predicted: 4, sequence was: deque([4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4]))
Predicted 4 but it was 8
HIT! Current hit rate: 34.75660255820374 (predicted: 4, sequence was: deque([4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4]))
HIT! Current hit rate: 34.758603766640086 (predicted: 4, sequence was: deque([4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4]))
HIT! Current hit rate: 34.7606048523142 (predicted: 4, sequence was: deque([4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4]))
HIT! Current hit rate: 34.76260581523739 (predicted: 4, sequence was: deque([4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4]))
HIT! Current hit rate: 34.76460665542095 (predicted: 4, sequence was: deque([4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4]))
HIT! Current hit rate: 34.76660737287616 (predicted: 4, sequence was: deque([4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4]))
Predicted 4 but it was 28
HIT! Current hit rate: 34.767541707556425 (predicted: 4, sequence was: deque([4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4]))
出了什么问题?
非常感谢您。
【问题讨论】:
-
我的猜测是 LSTM 最有效的行为是总是预测 4s,因为大约 90% 的数字是相同的。尝试让它预测正弦波或类似的东西,你应该会看到它在不改变架构的情况下快速学习它
-
感谢您的评论。但是,我刚刚尝试了循环序列 4、8、4、8、4、8、4、28。一开始,LSTM 在预测 4 和 8 之间有些分裂,但过了一段时间它才开始预测 4一直都是。