[Posted]: 2019-06-21 13:35:21
[Question]:
I'm trying to use Keras to train an LSTM autoencoder that reconstructs the input I feed it, but the result I get after the decoding part is all NaN. Here is my code:
# lstm autoencoder recreate sequence
from numpy import array
import numpy as np
from keras.models import Sequential
from keras.layers import LSTM
from keras.layers import Dense
from keras.layers import RepeatVector
from keras.layers import TimeDistributed
from keras.utils import plot_model
import pandas as pd
df = pd.read_csv('flight_data.csv',sep=',',header=None)
data = df.to_numpy()
print(data.shape)
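# define input sequence (sequence1/sequence2 are leftovers and are not used below)
sequence1 = array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
sequence2 = array([0.2, 0.4, 0.6, 0.4, 1.0, 1.2, 1.4, 1.6, 1.8])
# reshape input into [samples, timesteps, features]
n_in = 100
# keep the last rows; the reshape below needs exactly 100 rows x 24 columns
data = data[73666:,:]
sequence = data.reshape((1, n_in, 24))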
print(sequence)
# define model
model = Sequential()
model.add(LSTM(100, activation='relu', input_shape=(n_in,24)))
model.add(RepeatVector(n_in))
model.add(LSTM(100, activation='relu', return_sequences=True))
model.add(TimeDistributed(Dense(24)))
model.compile(optimizer='adam', loss='mse')
# fit model
model.fit(sequence, sequence, epochs=300, verbose=0)
plot_model(model, show_shapes=True, to_file='reconstruct_lstm_autoencoder.png')
# demonstrate recreation
yhat = model.predict(sequence, verbose=0)
print(yhat)
The output I get (first the printed input sequence, then the predicted yhat) is:
[[[9.46687355e+14 1.00000000e+01 4.42748822e+08 ... 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[9.46687355e+14 1.00000000e+01 4.42748822e+08 ... 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[9.46687355e+14 1.00000000e+01 4.42748823e+08 ... 0.00000000e+00
0.00000000e+00 0.00000000e+00]
...
[9.46687359e+14 1.00000000e+01 4.42748824e+08 ... 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[9.46687359e+14 1.00000000e+01 4.42748824e+08 ... 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[9.46687359e+14 1.00000000e+01 4.42748825e+08 ... 0.00000000e+00
0.00000000e+00 0.00000000e+00]]]
[[[nan nan nan ... nan nan nan]
[nan nan nan ... nan nan nan]
[nan nan nan ... nan nan nan]
...
[nan nan nan ... nan nan nan]
[nan nan nan ... nan nan nan]
[nan nan nan ... nan nan nan]]]
Which part could be causing this, and what should I do?
[Discussion]:
Tags: python keras lstm nan autoencoder
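A frequent cause of NaN in this kind of setup (an assumption here; the question does not confirm it) is unscaled input. The printed sequence contains values around 9.47e14, and with activation='relu' inside both LSTMs nothing bounds the recurrent state across 100 timesteps, so activations can overflow to inf and then turn into NaN. The sketch below min-max scales each feature to [0, 1] with NumPy before training and undoes the scaling afterwards. The names minmax_scale, minmax_unscale and the synthetic fake array are hypothetical and are not part of the original code.

```python
import numpy as np

def minmax_scale(x, eps=1e-12):
    """Scale each feature (last axis) of a [samples, timesteps, features]
    array to [0, 1]. Returns the scaled array plus the per-feature min and
    range needed to undo the scaling later."""
    flat = x.reshape(-1, x.shape[-1]).astype("float64")
    col_min = flat.min(axis=0)
    col_range = flat.max(axis=0) - col_min
    # constant columns (like the all-zero ones in the printout) would divide
    # by zero, so give them a range of 1; they simply scale to 0
    col_range = np.where(col_range < eps, 1.0, col_range)
    scaled = (x - col_min) / col_range  # broadcasts over the last axis
    return scaled.astype("float32"), col_min, col_range

def minmax_unscale(x_scaled, col_min, col_range):
    """Invert minmax_scale, e.g. on the model's predictions."""
    return x_scaled * col_range + col_min

# Hypothetical stand-in for `sequence` with magnitudes like the printout:
# one column near 9.47e14, one constant column of 10, one all-zero column.
fake = np.zeros((1, 100, 3))
fake[0, :, 0] = 9.46687355e14 + np.arange(100) * 4e4
fake[0, :, 1] = 10.0

# Quick diagnostic: per-feature maximum magnitude and a finiteness check.
print(np.abs(fake).max(axis=(0, 1)))
print(np.isfinite(fake).all())

scaled, mn, rg = minmax_scale(fake)

# With the real data one would then train on the scaled array, e.g.
#   model.fit(scaled, scaled, epochs=300, verbose=0)
#   yhat = minmax_unscale(model.predict(scaled), mn, rg)
```

Switching the LSTMs back to their default tanh activation, or passing clipnorm to the Adam optimizer, are other commonly suggested mitigations. Scaling the input is the first thing worth trying given the magnitudes shown above.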