【问题标题】:CNN- Trying to run a confusion matrix using seaborn.heatmapCNN-尝试使用 seaborn.heatmap 运行混淆矩阵
【发布时间】:2019-06-13 15:43:58
【问题描述】:

在我的 CNN 模型运行后,我一直在尝试运行混淆矩阵。

我的模型正在对狗/兔子进行分类。

以下是我所做的:

我将每个班级(狗/兔子)的照片放在两个文件夹内的单独文件夹中:培训和测试。

训练目录-> 兔子目录-> 兔子图片

训练目录->小狗目录->小狗图片

测试目录-> 兔子目录-> 兔子图片

测试目录->小狗目录->小狗图片

我使用以下代码从文件夹中获取图像:

training_data = train_datagen.flow_from_directory('./images/train',
                                             target_size = (28, 28),
                                             batch_size = 86,
                                             class_mode = 'binary',
                                             color_mode='rgb',
                                             classes=None)


test_data = test_datagen.flow_from_directory('./images/test',
                                        target_size = (28, 28),
                                        batch_size = 86,
                                        class_mode = 'binary',
                                        color_mode='rgb',
                                        classes=None)

我使用以下代码将图像分离到训练/验证中。

data_generator = ImageDataGenerator(
    validation_split=0.2,
    rescale=1./255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)

train_generator = data_generator.flow_from_directory(
    './images/train',
    target_size = (28, 28),
    batch_size = 86,
    class_mode = 'binary',
    color_mode='rgb',
    classes=None, subset="training"
)

validation_generator = data_generator.flow_from_directory(
    './images/train',
    target_size = (28, 28),
    batch_size = 86,
    class_mode = 'binary',
    color_mode='rgb',
    classes=None, subset="validation"
)

history=classifier.fit_generator(
    train_generator,
    steps_per_epoch = (8000 / 86),
    epochs = 2,
    validation_data = validation_generator,
    validation_steps = 8000/86,
    callbacks=[learning_rate_reduction]
)

当我尝试运行 confusion_matrix(validation_data) 时,我收到此错误:

TypeError: confusion_matrix() missing 1 required positional argument: 'y_pred'

当我跑步时

#Confusion matrix
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cm[i, j],
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

# Predict the values from the validation dataset
Y_pred = classifier.predict(training_data)
# Convert predictions classes to one hot vectors 
Y_pred_classes = np.argmax(Y_pred,axis = 1) 
# Convert validation observations to one hot vectors
Y_true = np.argmax(training_data,axis = 1) 
# compute the confusion matrix
confusion_mtx = confusion_matrix(Y_true, Y_pred_classes) 
# plot the confusion matrix
plot_confusion_matrix(confusion_mtx, classes = range(10))

sns.heatmap(confusion_mtx, annot=True, fmt='d')

我收到以下错误

AttributeError: 'DirectoryIterator' object has no attribute 'ndim'

【问题讨论】:

    标签: python-3.x machine-learning deep-learning conv-neural-network data-science


    【解决方案1】:

    据我了解,您希望使用混淆矩阵和热图来验证您的分类器模型。 我还对垃圾邮件文本分类进行了验证,所以这是你可以做的,

    对于混淆矩阵,

    from sklearn.metrics import confusion_matrix
    conf_mat = confusion_matrix(y_test, y_pred)
    print(conf_mat)
    

    对于热图,

    import seaborn as sns
    conf_mat = confusion_matrix(y_test, y_pred)
    conf_mat_normalized = conf_mat.astype('float') / conf_mat.sum(axis=1)[:, np.newaxis]
    sns.heatmap(conf_mat_normalized)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    

    仅仅意味着混淆矩阵需要两个参数(你的真实标签和预测标签列表)

    希望对你有所帮助。

    【讨论】:

      猜你喜欢
      • 2020-03-09
      • 2018-11-22
      • 1970-01-01
      • 2021-10-04
      • 1970-01-01
      • 2020-08-10
      • 2022-01-04
      • 2020-09-20
      • 1970-01-01
      相关资源
      最近更新 更多