【问题标题】:Seaborn heatmap confusion matrix display not displaying as expectedSeaborn 热图混淆矩阵显示未按预期显示
【发布时间】:2020-12-08 20:14:31
【问题描述】:

请指导我了解混淆矩阵的热图显示。我尝试了不同的无花果大小,但没有得到正确的显示。下面是我的代码和屏幕截图

def show_confusion_matrix(test_labels,predictions):
    confusion=sk_metrics.confusion_matrix(np.argmax(test_labels,axis=1),np.argmax(predictions,axis=1))
    confusion_normalized=confusion.astype('float')/confusion.sum(axis=1)
    #confusion_normalized=confusion_matrix(np.argmax(y_test,axis=1),np.argmax(predictions,axis=1))
    axis_labels=list(uniquelabel) ## unique labels has 120 dog breed names
    fig,ax=plt.subplots(figsize=(30,70))
    ax=sns.heatmap(confusion_normalized,xticklabels=axis_labels,yticklabels=axis_labels,
                   linewidths=0.10,cmap='Blues',annot=True,fmt='.2f',square=True)
    plt.title('Confusion_matrix')
    plt.ylabel("True Label")
    plt.xlabel("Predicted Label")

show_confusion_matrix(y_test,predictions)  

【问题讨论】:

  • 截图不可用。
  • 现已添加,请查看说明链接,谢谢

标签: python seaborn heatmap


【解决方案1】:

我发现的第一个问题是字体大小

from seaborn import set
from seaborn import set_style

set(font_scale=1.8)
set_style("darkgrid")

或者你可以设置样式whitegrid

你有很多功能,所以我建议你用一个简单的方法来应用掩码。

from numpy import zeros_like
from numpy import triu_indices_from

mask = zeros_like(confusion_normalized)
mask[triu_indices_from(mask)] = True

您需要使用confusion_normalized,因为您要绘制归一化的混淆矩阵。

from seaborn import axes_style
from matplotlib.pyplot import subplots

with axes_style("white"):
    f, ax = subplots(figsize=(15, 15))
    ax = heatmap(confusion_normalized, 
                 annot=True, 
                 mask=mask, 
                 vmax=1,
                 vmin=0,
                 square=True, 
                 cmap="YlGnBu",
                 linewidths=1.5, 
                 annot_kws={"size": 18})
    
savefig('heatmap.png')

An Example output

【讨论】:

  • 感谢@Ahmet,以下代码修改未按预期工作
  • 好吧,我不能 100% 准确,因为我没有您的数据。错误是什么?
  • 不是错误,但显示非常狭窄,显示与您的输出不一样,我有大约 120 个功能
  • 我会建议更改以下参数;减少set(font_scale=1.8)annot_kws={"size": 18}),增加f, ax = subplots(figsize=(15, 15))
  • 现在使用以下修改 ```axis_labels=list(uniquelabel) fig,ax=plt.subplots(figsize=(50,50)) ax=sns.heatmap(confusion_normalized, annot=True, xticklabels=axis_labels, yticklabels=axis_labels, mask=mask, square=True, cmap="YlGnBu", linewidths=1.5, fmt='.2f', annot_kws={"size":12}) print(ax.get_xlim() ) print(ax.get_ylim()) ax.set_ylim(100, 0) ax.set_xlim(0,100) ```
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 2020-03-09
  • 2020-07-25
  • 1970-01-01
  • 2018-01-02
  • 1970-01-01
  • 2020-05-05
  • 2021-08-15
相关资源
最近更新 更多