【Posted】: 2021-10-09 15:49:25
【Problem description】:
I am building a CNN with Keras on Google Colab. The dataset contains 3 classes with the same number of images per class. The images are organized in my Google Drive as
Images:
-- class 1
-- class 2
-- class 3
The code for reading the data and creating the CNN is here:
from math import ceil
import numpy as np
import keras
from keras.models import Sequential
from keras.layers import (Conv2D, BatchNormalization, Activation,
                          MaxPooling2D, Dropout, Flatten, Dense)
from keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

batch_size = 30
data = ImageDataGenerator(rescale=1. / 255,
                          validation_split=0.2)
training_data = data.flow_from_directory('/content/drive/My Drive/Data/Images',
                                         target_size=(200, 200), shuffle=True, batch_size=batch_size,
                                         class_mode='categorical', subset='training')
test_data = data.flow_from_directory('/content/drive/My Drive/Data/Images',
                                     target_size=(200, 200), batch_size=batch_size, shuffle=False,
                                     class_mode='categorical', subset='validation')
numBatchTest = ceil(len(test_data.filenames) / (1.0 * batch_size))      # 1.0 to avoid integer division
numBatchTrain = ceil(len(training_data.filenames) / (1.0 * batch_size))  # 1.0 to avoid integer division
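(As an aside: in Python 3, `/` already performs float division, so the `1.0 *` workaround is a Python 2 relic; an all-integer ceiling gives the same batch count. A minimal check of the arithmetic, using sample counts assumed from the training output below — 120 training images, batches of 30:)

```python
from math import ceil

batch_size = 30
for n in (120, 30, 121, 1):
    # Float-based ceiling (as in the code above) vs. a pure integer ceiling:
    assert ceil(n / (1.0 * batch_size)) == (n + batch_size - 1) // batch_size

print(ceil(120 / batch_size))  # number of training batches → 4
```

Keras directory iterators also expose this count directly as `len(training_data)`.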
numClasses = 3
Classifier=Sequential()
Classifier.add(Conv2D(32, kernel_size=(5, 5), input_shape=(200, 200, 3)))
Classifier.add(BatchNormalization())
Classifier.add(Activation('relu'))
Classifier.add(MaxPooling2D(pool_size=(2,2)))
Classifier.add(Dropout(0.2))
Classifier.add(Conv2D(64, kernel_size=(3, 3)))
Classifier.add(BatchNormalization())
Classifier.add(Activation('relu'))
Classifier.add(MaxPooling2D(pool_size=(2,2)))
Classifier.add(Dropout(0.2))
Classifier.add(Flatten())
Classifier.add(Dense(64, activation='relu'))
Classifier.add(Dense(32, activation='relu'))
Classifier.add(Dense(16, activation='relu'))
Classifier.add(Dense(8, activation='relu'))
Classifier.add(Dense(numClasses, activation='softmax'))
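For reference, the feature-map sizes this stack produces can be traced by hand (a sketch assuming the Keras defaults: 'valid' padding and stride 1 for the convolutions, non-overlapping 2x2 pooling):

```python
size = 200
size = size - 5 + 1   # Conv2D 5x5, 'valid' padding: 200 -> 196
size = size // 2      # MaxPooling2D 2x2:            196 -> 98
size = size - 3 + 1   # Conv2D 3x3, 'valid' padding:  98 -> 96
size = size // 2      # MaxPooling2D 2x2:             96 -> 48
flat_features = size * size * 64   # Flatten over 64 channels
print(size, flat_features)         # 48 147456
```

So the Flatten layer feeds a 147,456-dimensional vector into Dense(64), which dominates the parameter count.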
I train the network, using the test data as validation:
MyEpochs = 150
Classifier.compile(loss=keras.losses.categorical_crossentropy,
                   optimizer=keras.optimizers.SGD(learning_rate=0.01),
                   metrics=['accuracy'])
Classifier.fit(training_data,
               batch_size=30,
               epochs=MyEpochs,
               validation_data=test_data,
               shuffle=1)
The training output shows both accuracy and validation accuracy above 90%:
Epoch 135/150
4/4 [==============================] - 0s 123ms/step - loss: 0.0759 - accuracy: 0.9750 - val_loss: 0.1891 - val_accuracy: 0.9667
Epoch 136/150
4/4 [==============================] - 0s 124ms/step - loss: 0.1153 - accuracy: 0.9583 - val_loss: 0.2348 - val_accuracy: 0.9333
Epoch 137/150
4/4 [==============================] - 1s 134ms/step - loss: 0.1059 - accuracy: 0.9417 - val_loss: 0.1893 - val_accuracy: 0.9667
Epoch 138/150
4/4 [==============================] - 0s 122ms/step - loss: 0.0689 - accuracy: 0.9833 - val_loss: 0.1991 - val_accuracy: 0.9667
Epoch 139/150
4/4 [==============================] - 1s 131ms/step - loss: 0.0716 - accuracy: 0.9750 - val_loss: 0.2175 - val_accuracy: 0.9333
Epoch 140/150
4/4 [==============================] - 0s 125ms/step - loss: 0.1118 - accuracy: 0.9417 - val_loss: 0.2466 - val_accuracy: 0.9333
Epoch 141/150
4/4 [==============================] - 1s 126ms/step - loss: 0.1046 - accuracy: 0.9417 - val_loss: 0.2351 - val_accuracy: 0.9333
Epoch 142/150
4/4 [==============================] - 0s 120ms/step - loss: 0.0988 - accuracy: 0.9417 - val_loss: 0.1994 - val_accuracy: 0.9333
Epoch 143/150
4/4 [==============================] - 0s 124ms/step - loss: 0.0803 - accuracy: 0.9500 - val_loss: 0.1910 - val_accuracy: 0.9667
Epoch 144/150
4/4 [==============================] - 0s 124ms/step - loss: 0.0786 - accuracy: 0.9750 - val_loss: 0.1908 - val_accuracy: 0.9667
Epoch 145/150
4/4 [==============================] - 0s 124ms/step - loss: 0.0947 - accuracy: 0.9500 - val_loss: 0.4854 - val_accuracy: 0.8667
Epoch 146/150
4/4 [==============================] - 1s 128ms/step - loss: 0.2091 - accuracy: 0.9000 - val_loss: 0.1858 - val_accuracy: 0.9333
Epoch 147/150
4/4 [==============================] - 0s 124ms/step - loss: 0.0838 - accuracy: 0.9417 - val_loss: 0.1779 - val_accuracy: 0.9667
Epoch 148/150
4/4 [==============================] - 1s 128ms/step - loss: 0.0771 - accuracy: 0.9667 - val_loss: 0.1897 - val_accuracy: 0.9667
Epoch 149/150
4/4 [==============================] - 0s 120ms/step - loss: 0.0869 - accuracy: 0.9667 - val_loss: 0.1982 - val_accuracy: 0.9667
Epoch 150/150
4/4 [==============================] - 0s 119ms/step - loss: 0.0809 - accuracy: 0.9500 - val_loss: 0.2615 - val_accuracy: 0.9333
To test the model, I predict on the training data again:
training_data.reset()
test_data.reset()
predicted_scores = Classifier.predict(training_data, verbose=1)
predicted_labels = predicted_scores.argmax(axis=1)

train_labels = []
training_data.reset()
for i in range(0, numBatchTrain):
    train_labels = np.append(train_labels, (training_data[i][1]).argmax(axis=1))

print(train_labels)
print(predicted_labels)

acc_score = accuracy_score(train_labels, predicted_labels)
CFM = confusion_matrix(train_labels, predicted_labels)
print("\n", "Accuracy: " + str(format(acc_score, '.3f')))
print("\n", "CFM: \n", CFM)
print("\n", "Classification report: \n", classification_report(train_labels, predicted_labels))
I had some trouble getting the labels for training_data and test_data: when I simply used training_data.labels, they did not seem to be in the same order as the images, which is why I loop over the batches to append the labels. The results were just as bad when I used training_data.labels directly. The output of this code is:
4/4 [==============================] - 0s 71ms/step
[0. 2. 2. 0. 0. 1. 2. 2. 2. 1. 0. 0. 0. 1. 2. 0. 2. 0. 0. 1. 1. 1. 0. 0.
0. 2. 2. 0. 1. 2. 0. 2. 1. 1. 2. 2. 0. 1. 0. 2. 0. 1. 1. 0. 2. 2. 0. 2.
2. 2. 1. 2. 1. 0. 2. 2. 1. 2. 1. 0. 1. 2. 0. 1. 1. 1. 1. 2. 0. 0. 1. 1.
1. 1. 1. 1. 2. 0. 0. 2. 2. 0. 1. 1. 1. 0. 2. 1. 2. 1. 2. 1. 1. 2. 0. 2.
2. 0. 0. 2. 1. 0. 2. 0. 0. 1. 1. 2. 0. 0. 1. 1. 0. 0. 1. 2. 0. 2. 0. 2.]
[2 2 2 0 1 1 0 1 1 0 0 2 0 2 0 0 1 2 2 2 2 0 0 2 1 0 2 2 1 1 0 2 1 1 0 0 1
0 1 0 2 2 2 1 1 1 0 2 0 1 0 0 2 0 0 0 2 0 1 2 2 1 0 2 2 0 1 0 2 2 0 2 0 0
1 1 2 2 2 0 2 2 1 0 2 1 2 1 0 1 2 2 0 2 0 2 0 0 1 1 1 1 2 2 0 0 1 1 1 2 0
0 1 0 1 0 2 0 0 0]
Accuracy: 0.333
CFM:
[[14 10 16]
 [13 14 13]
 [18 10 12]]
Classification report:
              precision    recall  f1-score   support

         0.0       0.31      0.35      0.33        40
         1.0       0.41      0.35      0.38        40
         2.0       0.29      0.30      0.30        40

    accuracy                           0.33       120
   macro avg       0.34      0.33      0.33       120
weighted avg       0.34      0.33      0.33       120
During training, accuracy is very high for both the training and validation data, but at test time, on the same data I trained on, accuracy is only 33.3%.
I think the problem is that the class labels get mixed up somewhere, but I don't know how to fix it. The dataset itself is simple; building the same CNN in MATLAB, I get 100% accuracy on both the training and test data, but I cannot get it to work in Python.
Does anyone have a suggestion for how to get it to work in Python?
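(A toy illustration of the suspected label mix-up: if the generator re-shuffles between predict() and the loop that collects the labels, the two orderings no longer correspond, and even perfect predictions score at chance level — about 1/3 for three balanced classes. The numbers below are hypothetical, chosen to match the 120-sample dataset above:)

```python
import random

# 120 samples, 3 balanced classes, as in the dataset above
labels = [0, 1, 2] * 40
perfect_predictions = list(labels)   # pretend the model predicts every sample correctly

# Simulate the generator re-shuffling before the labels are collected
random.seed(0)
order = random.sample(range(len(labels)), len(labels))
reshuffled_labels = [labels[i] for i in order]

matches = sum(p == t for p, t in zip(perfect_predictions, reshuffled_labels))
accuracy = matches / len(labels)
print(round(accuracy, 3))   # close to 1/3, despite "perfect" predictions
```

One way to rule this out would be to run both predict() and the label collection against a generator created with shuffle=False, so that both orderings stay fixed and match the file order.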
【Discussion】:
Tags: python keras conv-neural-network