【Posted】: 2020-06-20 15:09:24
【Question】:
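I'm working with the fashion_mnist image dataset (60,000 small, square 28×28-pixel grayscale images) and trying to apply a CNN followed by an LSTM in cascade. This is the code I'm using: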
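from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (Conv2D, Dense, Dropout, Flatten, LSTM,
                                     MaxPooling2D, TimeDistributed)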
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
print("Shape of x_train: {}".format(x_train.shape))
print("Shape of y_train: {}".format(y_train.shape))
print()
print("Shape of x_test: {}".format(x_test.shape))
print("Shape of y_test: {}".format(y_test.shape))
# define CNN model
model = Sequential()
model.add(TimeDistributed(Conv2D(32, kernel_size=(3, 3),
                                 activation='relu',
                                 input_shape=(60000, 28, 28))))
model.add(TimeDistributed(Conv2D(64, (3, 3), activation='relu')))
model.add(TimeDistributed(MaxPooling2D(pool_size=(2, 2))))
model.add(TimeDistributed(Dropout(0.25)))
model.add(TimeDistributed(Flatten()))
## LSTM
model.add(LSTM(200, activation='relu', return_sequences=True))
model.add(Dense(128, activation='relu'))
model.compile(loss='categorical_crossentropy', optimizer='adam',
              metrics=['accuracy'])
# fit the model
model.fit(x_train, y_train, epochs=5)
test_loss, test_acc = model.evaluate(x_test, y_test)
print('Loss: {0} - Acc: {1}'.format(test_loss, test_acc))
I get an error as soon as the model.fit line runs. Can anyone help me fix it?
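The question does not include the error message, so the following is only a likely diagnosis read off the code. TimeDistributed(Conv2D(...)) expects 5-D input of shape (batch, timesteps, height, width, channels), while fashion_mnist.load_data() returns 3-D arrays of shape (60000, 28, 28); input_shape should also describe a single sample and never include the sample count (60000). In addition, categorical_crossentropy expects one-hot labels, whereas y_train holds integer class ids. A minimal NumPy sketch of the data preparation, assuming each image is treated as a one-frame sequence (timesteps=1); the helpers to_sequence_input and to_one_hot are hypothetical names, and random stand-in arrays replace the real dataset:

```python
import numpy as np

NUM_CLASSES = 10  # Fashion-MNIST has 10 clothing categories


def to_sequence_input(images, timesteps=1):
    """Hypothetical helper: reshape (N, 28, 28) grayscale images into the
    5-D layout (batch, timesteps, height, width, channels) that
    TimeDistributed(Conv2D) expects.

    Assumption: timesteps=1 treats every image as a one-frame sequence;
    a larger value groups consecutive images, so N must be divisible by it.
    """
    n, h, w = images.shape
    if n % timesteps != 0:
        raise ValueError("number of images must be divisible by timesteps")
    x = images.astype("float32") / 255.0  # scale pixel values to [0, 1]
    # add the time axis and a single grayscale channel axis
    return x.reshape(n // timesteps, timesteps, h, w, 1)


def to_one_hot(labels, num_classes=NUM_CLASSES):
    """Hypothetical helper: integer labels (N,) -> one-hot (N, num_classes),
    the label format categorical_crossentropy expects."""
    return np.eye(num_classes, dtype="float32")[labels]


# Stand-in arrays with the same per-sample shape as fashion_mnist.load_data()
x_demo = np.random.randint(0, 256, size=(6, 28, 28), dtype=np.uint8)
y_demo = np.array([0, 3, 9, 1, 1, 7])

x_seq = to_sequence_input(x_demo)
y_onehot = to_one_hot(y_demo)
print(x_seq.shape)     # (6, 1, 28, 28, 1)
print(y_onehot.shape)  # (6, 10)
```

With data in this layout, the per-sample shape passed to the first layer would be (1, 28, 28, 1), typically given on the wrapper, e.g. TimeDistributed(Conv2D(32, (3, 3), activation='relu'), input_shape=(1, 28, 28, 1)). For 10-way classification the head would usually be LSTM(..., return_sequences=False) followed by Dense(10, activation='softmax') rather than Dense(128); alternatively, the integer labels can be kept as-is if the model is compiled with loss='sparse_categorical_crossentropy'.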
【Discussion】:
Tags: python-3.x image tensorflow keras jupyter-notebook