【问题标题】:Plotting multiple confusion matrix side by side [duplicate]并排绘制多个混淆矩阵[重复]
【发布时间】:2020-09-01 15:11:06
【问题描述】:

我是新来的。这是我的第一个问题,希望能得到专家的解答。我有 5 个分类器模型,我正在尝试绘制它们的混淆矩阵。

from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
import collections

classifiers = {
    "Naive Bayes": GaussianNB(),
    "LogisiticRegression": LogisticRegression(),
    "KNearest": KNeighborsClassifier(),
    "Support Vector Classifier": SVC(),
    "DecisionTreeClassifier": DecisionTreeClassifier(),
}

然后

from sklearn.metrics import confusion_matrix
for key, classifier in classifiers.items(): 
    y_pred = classifier.fit(X_train, y_train).predict(X_test)
    cf_matrix=confusion_matrix(y_test, y_pred)
    print(cf_matrix)

这给了我

现在我正在尝试用下面的代码绘制它们,但图中没有显示任何数据。

fig, axn = plt.subplots(1,5, sharex=True, sharey=True)
cbar_ax = fig.add_axes([.91, .3, .03, .4])

for i, ax in enumerate(axn.flat):
    sns.heatmap(cf_matrix, ax=ax,
                cbar=i == 0,
                vmin=0, vmax=1,
                cbar_ax=None if i else cbar_ax)

fig.tight_layout(rect=[0, 0, .9, 1])

有人可以帮我完成这项工作吗?

【问题讨论】:

    标签: python matplotlib plot scikit-learn seaborn


    【解决方案1】:

    sklearnconfusion_matrix 上提供绘图功能。 有两种方法,

    我在这里使用了第二种方式,因为在第一种方式中删除颜色条非常冗长(有多个颜色条看起来很混乱)。

    import matplotlib.pyplot as plt
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
    from sklearn.datasets import load_iris
    from sklearn.naive_bayes import GaussianNB
    from sklearn.linear_model import LogisticRegression
    from sklearn.svm import SVC
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.ensemble import RandomForestClassifier
    
    classifiers = {
        "Naive Bayes": GaussianNB(),
        "LogisiticRegression": LogisticRegression(),
        "KNearest": KNeighborsClassifier(),
        "Support Vector Classifier": SVC(),
        "DecisionTreeClassifier": DecisionTreeClassifier(),
    }
    
    
    iris = load_iris()
    X, y = iris.data, iris.target
    
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    
    
    f, axes = plt.subplots(1, 5, figsize=(20, 5), sharey='row')
    
    for i, (key, classifier) in enumerate(classifiers.items()):
        y_pred = classifier.fit(X_train, y_train).predict(X_test)
        cf_matrix = confusion_matrix(y_test, y_pred)
        disp = ConfusionMatrixDisplay(cf_matrix,
                                      display_labels=iris.target_names)
        disp.plot(ax=axes[i], xticks_rotation=45)
        disp.ax_.set_title(key)
        disp.im_.colorbar.remove()
        disp.ax_.set_xlabel('')
        if i!=0:
            disp.ax_.set_ylabel('')
    
    f.text(0.4, 0.1, 'Predicted label', ha='left')
    plt.subplots_adjust(wspace=0.40, hspace=0.1)
    
    
    f.colorbar(disp.im_, ax=axes)
    plt.show()
    
    

    【讨论】:

    • 酷,不知道这个功能..哈哈从来没有可视化它
    • 它是最近添加的。大声笑,你只需要在做演示时
    【解决方案2】:

    您需要将混淆矩阵存储在某处,因此如果我使用示例数据集:

    import pandas as pd
    from sklearn.datasets import load_breast_cancer
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import StandardScaler
    import matplotlib.pyplot as plt
    import seaborn as sns
    
    data = load_breast_cancer()
    scaler = StandardScaler()
    
    X_df = pd.DataFrame(data.data, columns=data.feature_names)
    X_df = scaler.fit_transform(X_df)
    y_df = pd.DataFrame(data.target, columns=['target'])
    
    X_train, X_test, y_train, y_test = train_test_split(X_df, y_df, test_size=0.2, random_state=11)
    

    并将其存储在类似的字典中:

    from sklearn.metrics import confusion_matrix
    cf_matrix = dict.fromkeys(classifiers.keys())
    for key, classifier in classifiers.items(): 
        y_pred = classifier.fit(X_train, y_train.values.ravel()).predict(X_test)
        cf_matrix[key]=confusion_matrix(y_test, y_pred)
    

    然后你可以绘制它:

    fig, axn = plt.subplots(1,5, sharex=True, sharey=True,figsize=(12,2))
    
    for i, ax in enumerate(axn.flat):
        k = list(cf_matrix)[i]
        sns.heatmap(cf_matrix[k], ax=ax,cbar=i==4)
        ax.set_title(k,fontsize=8)
    

    【讨论】:

    • 非常感谢。我按照您的建议进行了更改,结果如下
    • 嗨@hadi0815,欢迎来到stackoverflow。您发布了一个问题,我相信我已经用良好的可重现代码回答了它。如果您还有其他问题超出此问题,请将其作为单独的问题发布
    猜你喜欢
    • 2016-01-31
    • 2020-07-15
    • 2016-06-04
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2017-10-17
    • 2020-10-24
    相关资源
    最近更新 更多