【问题标题】:Seaborn heatmap fit annotation text to cellSeaborn 热图使注释文本适合单元格
【发布时间】:2021-12-05 15:32:25
【问题描述】:

我有显示混淆矩阵的代码。在每个单元格中,首先显示准确度,然后在其下方显示正确预测样本数/总样本数。现在我想显示每个单元格内的所有文本。例如,第一个单元格应在精度下显示 186/208。 如何在单元格内显示注释的全文?我试图减小字体大小,但没有奏效。

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def cm_analysis(cm, labels, figsize=(20,15)):
    cm_sum = np.sum(cm, axis=1, keepdims=True)
    cm_perc = cm / cm_sum.astype(float)
    annot = np.empty_like(cm).astype(str)
    nrows, ncols = cm.shape
    for i in range(nrows):
        for j in range(ncols):
            c = cm[i, j]
            p = cm_perc[i, j]
            if i == j:
                s = cm_sum[i]
                annot[i, j] = '%.2f%%\n%d/%d' % (p, c, s)
            elif c == 0:
                annot[i, j] = ''
            else:
                annot[i, j] = '%.2f%%\n%d' % (p, c)
    cm = pd.DataFrame(cm, index=labels, columns=labels)
    cm.index.name = 'Groundtruth labels'
    cm.columns.name = 'Predicted labels'
    fig, ax = plt.subplots(figsize=figsize)
    ax.axhline(color='black')

    g =sns.heatmap(cm, cmap="BuPu", annot_kws={"weight": "bold"}, annot=annot, fmt='', ax=ax, cbar_kws={'label': 'Number of samples'}, linewidths=0.1, linecolor='black')
    g.set_xticklabels(g.get_xticklabels(), rotation = 45)
    sns.set(font_scale=1.1)
    plt.savefig("filename.png")

normalised_confusion_matrix  = np.array(
[[186,3,0,1,2,0,3,3,7,1,2,0,0],
 [5,9,1,0,3,0,0,0,0,0,0,0,1],
 [0,0,49,3,0,0,0,0,1,0,0,0,6],
 [1,0,6,89,0,0,0,0,1,1,1,0,1],
 [3,7,0,0,50,0,0,0,6,0,1,0,0],
 [1,0,0,0,0,9,0,1,0,0,0,0,0],
 [3,0,1,0,0,0,54,0,0,0,3,0,0],
 [2,0,0,0,0,0,2,7,0,0,0,0,0],
 [3,0,0,0,2,1,2,0,53,2,4,0,0],
 [0,0,0,1,0,1,0,0,1,7,0,1,0],
 [1,1,0,0,1,0,1,0,3,0,52,0,0],
 [1,0,0,0,0,0,0,0,1,0,0,5,0],
 [0,0,11,2,0,0,0,0,0,0,0,0,26]]
)

classes = ['Assemble system','Consult sheets','Picking in front','Picking left','Put down component','Put down measuring rod','Put down screwdriver','Put down subsystem','Take component','Take measuring rod','Take screwdriver','Take subsystem','Turn sheets']

    
cm_analysis(cm= normalised_confusion_matrix, labels = classes)

【问题讨论】:

    标签: python seaborn


    【解决方案1】:

    主要问题是将annot 数组创建为str 类型而不是object(所以,annot = np.empty_like(cm).astype(object))。拥有str 类型的它会导致奇怪的错误,因为 numpy 字符串有一些内置的最大长度。 (另见this post。)

    由于您在cm_sum[i] 中只使用一个索引,因此最好不要在cm_sum = np.sum(cm, axis=1, keepdims=False) (docs) 中“保留维度”。

    另外,请注意,对于百分比,您需要乘以 100。(创建格式化字符串的现代方法是使用 f-strings:annot[i, j] = f'{p*100:.2f}%\n{c}/{s}')。

    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    
    def cm_analysis(cm, labels, figsize=(20, 15)):
        cm_sum = np.sum(cm, axis=1, keepdims=False)
        cm_perc = cm / cm_sum.astype(float)
        annot = np.empty_like(cm).astype(object)
        nrows, ncols = cm.shape
        for i in range(nrows):
            for j in range(ncols):
                c = cm[i, j]
                p = cm_perc[i, j]
                if i == j:
                    s = cm_sum[i]
                    annot[i, j] = f'{p*100:.1f}%\n{c}/{s}'
                elif c == 0:
                    annot[i, j] = ''
                else:
                    annot[i, j] = f'{p*100:.1f}%\n{c}'
        cm = pd.DataFrame(cm, index=labels, columns=labels)
        cm.index.name = 'Groundtruth labels'
        cm.columns.name = 'Predicted labels'
        fig, ax = plt.subplots(figsize=figsize)
        ax.axhline(color='black')
    
        g = sns.heatmap(cm, cmap="BuPu", annot_kws={"weight": "bold"}, annot=annot, fmt='', ax=ax,
                        cbar_kws={'label': 'Number of samples'}, linewidths=0.1, linecolor='black')
        g.set_xticklabels(g.get_xticklabels(), rotation=45)
        sns.set(font_scale=1.1)
        plt.savefig("filename.png")
    
    normalised_confusion_matrix = np.array(
        [[186, 3, 0, 1, 2, 0, 3, 3, 7, 1, 2, 0, 0],
         [5, 9, 1, 0, 3, 0, 0, 0, 0, 0, 0, 0, 1],
         [0, 0, 49, 3, 0, 0, 0, 0, 1, 0, 0, 0, 6],
         [1, 0, 6, 89, 0, 0, 0, 0, 1, 1, 1, 0, 1],
         [3, 7, 0, 0, 50, 0, 0, 0, 6, 0, 1, 0, 0],
         [1, 0, 0, 0, 0, 9, 0, 1, 0, 0, 0, 0, 0],
         [3, 0, 1, 0, 0, 0, 54, 0, 0, 0, 3, 0, 0],
         [2, 0, 0, 0, 0, 0, 2, 7, 0, 0, 0, 0, 0],
         [3, 0, 0, 0, 2, 1, 2, 0, 53, 2, 4, 0, 0],
         [0, 0, 0, 1, 0, 1, 0, 0, 1, 7, 0, 1, 0],
         [1, 1, 0, 0, 1, 0, 1, 0, 3, 0, 52, 0, 0],
         [1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 5, 0],
         [0, 0, 11, 2, 0, 0, 0, 0, 0, 0, 0, 0, 26]]
    )
    
    classes = ['Assemble system', 'Consult sheets', 'Picking in front', 'Picking left', 'Put down component',
               'Put down measuring rod', 'Put down screwdriver', 'Put down subsystem', 'Take component',
               'Take measuring rod', 'Take screwdriver', 'Take subsystem', 'Turn sheets']
    
    cm_analysis(cm=normalised_confusion_matrix, labels=classes)
    

    【讨论】:

    • 完美。谢谢!
    猜你喜欢
    • 1970-01-01
    • 2020-11-28
    • 2021-12-23
    • 2022-07-09
    • 2016-01-14
    • 1970-01-01
    • 1970-01-01
    • 2020-04-19
    相关资源
    最近更新 更多