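Posted: 2020-03-08 13:39:39
Question:
I want to evaluate the model's accuracy and also build a confusion matrix for all 10 classes of the CIFAR-10 dataset, but I get this error message: "Error when checking input: expected conv2d_9_input to have 4 dimensions, but got array with shape (10000, 10)"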
def run_test_harness():
    # load dataset
    trainX, trainY, testX, testY = load_dataset()
    # prepare pixel data
    trainX, testX = prep_pixels(trainX, testX)
    # define model
    model = define_model()
    # fit model
    history = model.fit(trainX, trainY, epochs=100, batch_size=64, validation_data=(testX, testY), verbose=0)
    # fig
    y_pred = model.predict_classes(testY)
    con_mat = tf.math.confusion_matrix(labels=y_true, predictions=y_pred).numpy()
    con_mat_norm = np.around(con_mat.astype('float') / con_mat.sum(axis=1)[:, np.newaxis], decimals=2)
    con_mat_df = pd.DataFrame(con_mat_norm, index=classes, columns=classes)
    figure = plt.figure(figsize=(8, 8))
    sns.heatmap(con_mat_df, annot=True, cmap=plt.cm.Blues)
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()
    accuracy, precision, recall = model.evaluate(testX, testY, verbose=0)
    print("recall")
    print('> %.3f' % (recall * 100.0))
    print("accuracy")
    print('> %.3f' % (accuracy * 100.0))
    print("precision")
    print('> %.3f' % (precision * 100.0))
    # learning curves # end
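The error comes from `model.predict_classes(testY)`: the first `Conv2D` layer expects a batch of images shaped `(n, 32, 32, 3)`, but `testY` holds the labels, shaped `(10000, 10)`. Predictions have to be made on `testX`, and `y_true` (never defined in the snippet) has to be recovered from `testY`. `predict_classes` has also been removed from newer TensorFlow releases, so taking `np.argmax` of `model.predict(testX)` is the more portable route. Below is a minimal NumPy-only sketch, assuming `testY` is one-hot encoded (which its `(10000, 10)` shape suggests); the helper names `confusion_from_predictions` and `normalize_rows` are hypothetical. Rows are true labels and columns are predicted labels, the same convention `tf.math.confusion_matrix` uses, and `CIFAR10_CLASSES` can serve as the undefined `classes` list.

```python
import numpy as np

# CIFAR-10 label order (index 0..9); usable as `classes` for the heatmap.
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]


def confusion_from_predictions(probs, y_onehot, num_classes=10):
    """Hypothetical helper: confusion matrix from softmax outputs and one-hot labels.

    probs:    (n_samples, num_classes), e.g. model.predict(testX)  <- testX, not testY
    y_onehot: (n_samples, num_classes), e.g. testY (assumed one-hot encoded)
    """
    y_pred = np.argmax(probs, axis=1)     # predicted class index per sample
    y_true = np.argmax(y_onehot, axis=1)  # true class index per sample
    con_mat = np.zeros((num_classes, num_classes), dtype=np.int64)
    # Row = true label, column = predicted label; np.add.at counts repeated pairs correctly.
    np.add.at(con_mat, (y_true, y_pred), 1)
    return con_mat


def normalize_rows(con_mat, decimals=2):
    """Divide each row by its total, so each row shows the fraction per predicted class."""
    row_sums = con_mat.sum(axis=1, keepdims=True)
    # Classes with no samples would divide by zero; report 0.0 for them instead of NaN.
    with np.errstate(divide="ignore", invalid="ignore"):
        norm = np.where(row_sums > 0, con_mat / row_sums, 0.0)
    return np.around(norm, decimals=decimals)


if __name__ == "__main__":
    # Toy data standing in for model.predict(testX) and testY (3 classes, 4 samples).
    probs = np.array([[0.8, 0.1, 0.1],
                      [0.2, 0.7, 0.1],
                      [0.1, 0.6, 0.3],
                      [0.1, 0.2, 0.7]])
    y_onehot = np.eye(3)[[0, 1, 2, 2]]
    cm = confusion_from_predictions(probs, y_onehot, num_classes=3)
    print(cm)
    print(normalize_rows(cm))
```

In `run_test_harness` this would replace the `predict_classes` and `tf.math.confusion_matrix` lines with `probs = model.predict(testX)`, `con_mat = confusion_from_predictions(probs, testY)` and `con_mat_norm = normalize_rows(con_mat)`; the seaborn heatmap code can stay as it is with `classes = CIFAR10_CLASSES`. If you prefer TensorFlow's function, `tf.math.confusion_matrix(labels=y_true, predictions=y_pred)` accepts the same two index arrays.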
This is the CNN implementation:
def define_model():
    model = Sequential()
    model.add(Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same', input_shape=(32, 32, 3)))
    model.add(Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'))
    model.add(MaxPooling2D((2, 2)))
    model.add(Flatten())
    model.add(Dense(128, activation='relu', kernel_initializer='he_uniform'))
    model.add(Dense(10, activation='softmax'))
    # compile model
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy', precision_m, recall_m])
    return model
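Two further problems are likely to surface once the shape error is fixed. First, `model.evaluate` returns the loss followed by each compiled metric, so with `metrics=['accuracy', precision_m, recall_m]` it yields four values, and `accuracy, precision, recall = model.evaluate(...)` should fail with a "too many values to unpack" `ValueError`. Second, for a 10-way softmax with one-hot labels the usual loss is `categorical_crossentropy`, i.e. `model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy', precision_m, recall_m])`; with `binary_crossentropy`, Keras typically resolves `'accuracy'` to binary accuracy, which looks inflated on this task. The sketch below pairs the returned scores with their names instead of relying on unpacking order. It assumes tf.keras 2.x, where `model.metrics_names` lists `'loss'` first and custom metric functions appear under their function names (`precision_m`, `recall_m`, which the question uses but does not show); the helpers `scores_by_name` and `print_percentages` are hypothetical.

```python
def scores_by_name(names, scores):
    """Hypothetical helper: pair model.evaluate() outputs with model.metrics_names.

    Keras returns [loss, metric_1, metric_2, ...] in the order given to compile(),
    so looking values up by name avoids miscounting when unpacking.
    """
    if len(names) != len(scores):
        raise ValueError(f"{len(names)} names but {len(scores)} scores")
    return dict(zip(names, scores))


def print_percentages(results, keys=("accuracy", "precision_m", "recall_m")):
    """Print selected metrics in the '> %.3f' style used in the question."""
    for key in keys:
        print(key)
        print("> %.3f" % (results[key] * 100.0))


if __name__ == "__main__":
    # With a compiled Keras model this would be (assumption: tf.keras 2.x):
    #   scores = model.evaluate(testX, testY, verbose=0)
    #   results = scores_by_name(model.metrics_names, scores)
    # Made-up toy values stand in for a real evaluation here.
    names = ["loss", "accuracy", "precision_m", "recall_m"]
    scores = [0.9, 0.71, 0.75, 0.66]
    results = scores_by_name(names, scores)
    print_percentages(results)
```

Looking metrics up by name also keeps the printout correct if another metric is later added to `compile`.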
Comments:
-
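Where is the prep_pixels code? It looks like you are feeding testY to the model. Could the problem be in your load_dataset or prep_pixels? Could you post them as well? As far as I know, cifar10.load_data() returns two tuples, which can be unpacked the same way you unpack trainX, trainY, testX, testY = load_dataset().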
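The question does not show `load_dataset` or `prep_pixels`. A common pattern for CIFAR-10, assumed here and not taken from the asker's code, is to one-hot encode the labels to shape `(n, 10)` and scale the pixels to [0, 1]. The sketch below is a hypothetical NumPy version of those steps (`one_hot` stands in for Keras's `to_categorical`), plus a `check_image_batch` guard that would flag the mistake before Keras does: whatever is passed to `predict` must be a 4-D batch `(n, 32, 32, 3)`, while `testY` is 2-D.

```python
import numpy as np


def one_hot(labels, num_classes=10):
    """Hypothetical stand-in for to_categorical: (n,) or (n, 1) ints -> (n, num_classes)."""
    labels = np.asarray(labels).reshape(-1)
    return np.eye(num_classes, dtype=np.float32)[labels]


def prep_pixels(train, test):
    """Hypothetical prep_pixels: cast uint8 images to float32 and scale to [0, 1]."""
    return train.astype("float32") / 255.0, test.astype("float32") / 255.0


def check_image_batch(x, image_shape=(32, 32, 3)):
    """Raise a readable error unless x is a batch of images shaped (n, *image_shape)."""
    x = np.asarray(x)
    if x.ndim != 1 + len(image_shape) or x.shape[1:] != tuple(image_shape):
        raise ValueError(f"expected shape (n, {', '.join(map(str, image_shape))}), got {x.shape}")
    return x


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Random stand-ins for CIFAR-10 test images and labels (5 samples).
    testX = rng.integers(0, 256, size=(5, 32, 32, 3)).astype(np.uint8)
    testY = one_hot(rng.integers(0, 10, size=(5, 1)))
    _, testX = prep_pixels(testX, testX)
    check_image_batch(testX)          # passes: shape (5, 32, 32, 3)
    try:
        check_image_batch(testY)      # fails: shape (5, 10), like (10000, 10) in the question
    except ValueError as err:
        print(err)
```

Calling `check_image_batch(testX)` right before `model.predict` is cheap and turns the Keras message into one that names the offending shape directly.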
Tags: keras conv-neural-network confusion-matrix