【Question title】: How to plot scikit learn classification report?
【Posted】: 2015-03-27 20:33:36
【Question】:

Is it possible to plot the scikit-learn classification report with matplotlib? Suppose I print the classification report like this:

print('\n*Classification Report:\n', classification_report(y_test, predictions))
confusion_matrix_graph = confusion_matrix(y_test, predictions)

I get:

Classification Report:
             precision    recall  f1-score   support

          1       0.62      1.00      0.76        66
          2       0.93      0.93      0.93        40
          3       0.59      0.97      0.73        67
          4       0.47      0.92      0.62       272
          5       1.00      0.16      0.28       413

avg / total       0.77      0.57      0.49       858

How can I "plot" the above table?

【Comments】:

    Tags: python numpy matplotlib scikit-learn


    【Solution 1】:

    You can do something like this:

    import matplotlib.pyplot as plt
    
    cm =  [[0.50, 1.00, 0.67],
           [0.00, 0.00, 0.00],
           [1.00, 0.67, 0.80]]
    labels = ['class 0', 'class 1', 'class 2']
    fig, ax = plt.subplots()
    h = ax.matshow(cm)
    fig.colorbar(h)
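    # Pin the tick positions to the cell centres before labelling them.
    ax.set_xticks(range(len(labels)))
    ax.set_yticks(range(len(labels)))
    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)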
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Ground truth')
    

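    【Comments】:

    • Thanks for the help. I edited the question because I had left out the metrics I am using. Is there any way to see what is happening with the precision, recall, f1-score and support metrics?
    • I notice that this accepted answer visualizes a confusion matrix rather than the classification report.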
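
As the second comment points out, Solution 1 labels its axes like a confusion matrix, while the question asks about precision, recall, f1-score and support. As a minimal sketch of how the two relate, the per-class report values can be derived from a confusion matrix of raw counts. The function name `report_from_confusion` and the toy counts are illustrative assumptions; they are not part of scikit-learn or of the answer above.

```python
def report_from_confusion(cm):
    """Derive (precision, recall, f1, support) for every class.

    cm is a square list of lists of counts: rows are true classes and
    columns are predicted classes.
    """
    n = len(cm)
    rows = []
    for k in range(n):
        tp = cm[k][k]                                # samples of class k predicted as k
        support = sum(cm[k])                         # all true samples of class k
        predicted = sum(cm[i][k] for i in range(n))  # all samples predicted as k
        precision = tp / predicted if predicted else 0.0
        recall = tp / support if support else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        rows.append((precision, recall, f1, support))
    return rows


# Hypothetical toy counts for three classes (not taken from the question).
toy_cm = [[1, 0, 0],
          [1, 0, 0],
          [0, 1, 2]]

for label, (p, r, f, s) in enumerate(report_from_confusion(toy_cm)):
    print('class {0}: precision={1:.2f} recall={2:.2f} f1={3:.2f} support={4}'.format(
        label, p, r, f, s))
```

With these toy counts the three metric columns round to 0.50/1.00/0.67, 0.00/0.00/0.00 and 1.00/0.67/0.80, which are the same numbers used as `cm` in Solution 1.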
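    【Solution 2】:

    I just wrote a function, plot_classification_report(), for this purpose. Hope it helps. The function takes the output of the classification_report function as an argument and plots the scores. Here is the function.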

    import matplotlib.pyplot as plt
    import numpy as np

    def plot_classification_report(cr, title='Classification report ', with_avg_total=False, cmap=plt.cm.Blues):
    
        lines = cr.split('\n')
    
        classes = []
        plotMat = []
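        for line in lines[2:]:
            t = line.split()
            if len(t) == 0:
                # A blank line marks the end of the per-class rows.
                break
            classes.append(t[0])
            # Keep precision, recall and f1-score; drop the trailing support count.
            v = [float(x) for x in t[1: len(t) - 1]]
            plotMat.append(v)

        if with_avg_total:
            # The summary row is the last non-empty line, e.g. 'avg / total  0.77  0.57  0.49  858'.
            aveTotal = [l for l in lines if l.strip()][-1].split()
            classes.append('avg/total')
            # The last four tokens are precision, recall, f1-score and support.
            vAveTotal = [float(x) for x in aveTotal[-4:-1]]
            plotMat.append(vAveTotal)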
    
    
        plt.imshow(plotMat, interpolation='nearest', cmap=cmap)
        plt.title(title)
        plt.colorbar()
        x_tick_marks = np.arange(3)
        y_tick_marks = np.arange(len(classes))
        plt.xticks(x_tick_marks, ['precision', 'recall', 'f1-score'], rotation=45)
        plt.yticks(y_tick_marks, classes)
        plt.tight_layout()
        plt.ylabel('Classes')
        plt.xlabel('Measures')
    

    For the sample classification report you provided, here are the code and the output.

    sampleClassificationReport = """             precision    recall  f1-score   support
    
              1       0.62      1.00      0.76        66
              2       0.93      0.93      0.93        40
              3       0.59      0.97      0.73        67
              4       0.47      0.92      0.62       272
              5       1.00      0.16      0.28       413
    
    avg / total       0.77      0.57      0.49       858"""
    
    
    plot_classification_report(sampleClassificationReport)
    

    Here is how to use it with the sklearn classification_report output:

    from sklearn.metrics import classification_report
    classificationReport = classification_report(y_true, y_pred, target_names=target_names)
    
    plot_classification_report(classificationReport)
    

    With this function you can also add the "avg / total" row to the plot. To do so, just pass the with_avg_total argument like this:

    plot_classification_report(classificationReport, with_avg_total=True)
    

    【Comments】:

    • Partial bug fix: for line in lines[2 : (len(lines) - 3)]: #print(line) t = line.split() # print(t) if(len(t)==0): break
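
The bug fix in the comment above replaces reliance on a hard-coded line offset with a check for the blank line that ends the per-class rows. A minimal, plotting-free sketch of that idea follows. `parse_class_rows` is a hypothetical helper name, and the sketch assumes that the report starts with its header line and that class labels contain no whitespace.

```python
def parse_class_rows(report):
    """Return (classes, scores, support) for the per-class rows of a report string."""
    classes, scores, support = [], [], []
    started = False
    for line in report.split('\n')[1:]:   # skip the header line
        tokens = line.split()
        if not tokens:
            if started:
                break      # blank line after the class rows: only summary rows follow
            continue       # blank line between the header and the first class row
        started = True
        classes.append(tokens[0])
        scores.append([float(x) for x in tokens[1:-1]])  # precision, recall, f1-score
        support.append(int(tokens[-1]))
    return classes, scores, support


# The report from the question, with the older 'avg / total' summary row.
sample_report = """             precision    recall  f1-score   support

          1       0.62      1.00      0.76        66
          2       0.93      0.93      0.93        40
          3       0.59      0.97      0.73        67
          4       0.47      0.92      0.62       272
          5       1.00      0.16      0.28       413

avg / total       0.77      0.57      0.49       858
"""

classes, scores, support = parse_class_rows(sample_report)
print(classes)
print(scores[0], support[0])
```

Because the loop stops at the blank line rather than at a fixed index, the same sketch should also cope with reports that end in several summary rows (`accuracy`, `macro avg`, `weighted avg`).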
    【Solution 3】:

    Extending Bin's answer:

    import matplotlib.pyplot as plt
    import numpy as np
    
    def show_values(pc, fmt="%.2f", **kw):
        '''
        Heatmap with text in each cell with matplotlib's pyplot
        Source: https://stackoverflow.com/a/25074150/395857 
        By HYRY
        '''
        pc.update_scalarmappable()
        # Use the axes attribute; pc.get_axes() is not available in current matplotlib.
        ax = pc.axes
        # The built-in zip replaces itertools.izip, which does not exist in Python 3.
        for p, color, value in zip(pc.get_paths(), pc.get_facecolors(), pc.get_array()):
            x, y = p.vertices[:-2, :].mean(0)
            if np.all(color[:3] > 0.5):
                color = (0.0, 0.0, 0.0)
            else:
                color = (1.0, 1.0, 1.0)
            ax.text(x, y, fmt % value, ha="center", va="center", color=color, **kw)
    
    
    def cm2inch(*tupl):
        '''
        Specify figure size in centimeter in matplotlib
        Source: https://stackoverflow.com/a/22787457/395857
        By gns-ank
        '''
        inch = 2.54
        if type(tupl[0]) == tuple:
            return tuple(i/inch for i in tupl[0])
        else:
            return tuple(i/inch for i in tupl)
    
    
    def heatmap(AUC, title, xlabel, ylabel, xticklabels, yticklabels, figure_width=40, figure_height=20, correct_orientation=False, cmap='RdBu'):
        '''
        Inspired by:
        - https://stackoverflow.com/a/16124677/395857 
        - https://stackoverflow.com/a/25074150/395857
        '''
    
        # Plot it out
        fig, ax = plt.subplots()    
        #c = ax.pcolor(AUC, edgecolors='k', linestyle= 'dashed', linewidths=0.2, cmap='RdBu', vmin=0.0, vmax=1.0)
        c = ax.pcolor(AUC, edgecolors='k', linestyle= 'dashed', linewidths=0.2, cmap=cmap)
    
        # put the major ticks at the middle of each cell
        ax.set_yticks(np.arange(AUC.shape[0]) + 0.5, minor=False)
        ax.set_xticks(np.arange(AUC.shape[1]) + 0.5, minor=False)
    
        # set tick labels
        #ax.set_xticklabels(np.arange(1,AUC.shape[1]+1), minor=False)
        ax.set_xticklabels(xticklabels, minor=False)
        ax.set_yticklabels(yticklabels, minor=False)
    
        # set title and x/y labels
        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)      
    
        # Remove last blank column
        plt.xlim( (0, AUC.shape[1]) )
    
        # Turn off all the ticks
        ax = plt.gca()    
        for t in ax.xaxis.get_major_ticks():
            t.tick1line.set_visible(False)
            t.tick2line.set_visible(False)
        for t in ax.yaxis.get_major_ticks():
            t.tick1line.set_visible(False)
            t.tick2line.set_visible(False)
    
        # Add color bar
        plt.colorbar(c)
    
        # Add text in each cell 
        show_values(c)
    
        # Proper orientation (origin at the top left instead of bottom left)
        if correct_orientation:
            ax.invert_yaxis()
            ax.xaxis.tick_top()       
    
        # resize 
        fig = plt.gcf()
        #fig.set_size_inches(cm2inch(40, 20))
        #fig.set_size_inches(cm2inch(40*4, 20*4))
        fig.set_size_inches(cm2inch(figure_width, figure_height))
    
    
    
    def plot_classification_report(classification_report, title='Classification report ', cmap='RdBu'):
        '''
        Plot scikit-learn classification report.
        Extension based on https://stackoverflow.com/a/31689645/395857 
        '''
        lines = classification_report.split('\n')
    
        classes = []
        plotMat = []
        support = []
        class_names = []
        for line in lines[2 : (len(lines) - 2)]:
            t = line.strip().split()
            if len(t) < 2: continue
            classes.append(t[0])
            v = [float(x) for x in t[1: len(t) - 1]]
            support.append(int(t[-1]))
            class_names.append(t[0])
            print(v)
            plotMat.append(v)
    
        print('plotMat: {0}'.format(plotMat))
        print('support: {0}'.format(support))
    
        xlabel = 'Metrics'
        ylabel = 'Classes'
        xticklabels = ['Precision', 'Recall', 'F1-score']
        yticklabels = ['{0} ({1})'.format(class_names[idx], sup) for idx, sup  in enumerate(support)]
        figure_width = 25
        figure_height = len(class_names) + 7
        correct_orientation = False
        heatmap(np.array(plotMat), title, xlabel, ylabel, xticklabels, yticklabels, figure_width, figure_height, correct_orientation, cmap=cmap)
    
    
    def main():
        sampleClassificationReport = """             precision    recall  f1-score   support
    
              Acacia       0.62      1.00      0.76        66
              Blossom       0.93      0.93      0.93        40
              Camellia       0.59      0.97      0.73        67
              Daisy       0.47      0.92      0.62       272
              Echium       1.00      0.16      0.28       413
    
            avg / total       0.77      0.57      0.49       858"""
    
    
        plot_classification_report(sampleClassificationReport)
        plt.savefig('test_plot_classif_report.png', dpi=200, format='png', bbox_inches='tight')
        plt.close()
    
    if __name__ == "__main__":
        main()
        #cProfile.run('main()') # if you want to do some profiling
    

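    Output: [figure]

    Example with more classes (~40): [figure]

    【Comments】:

    • If itertools is not available, delete "from itertools import izip" and replace izip with zip.
    • The solution as posted does not seem to work with the current version of matplotlib. The line ax = pc.get_axes() must be changed to ax = pc.axes
    • But why use izip at all? It is slower than zip and incompatible with Python 3: stackoverflow.com/questions/32659552/…
    • Is there a way to make this work with the latest output produced by classification_report?
    【Solution 4】:

    Here is my simple solution, using a seaborn heatmap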

    import seaborn as sns
    import numpy as np
    from sklearn.metrics import precision_recall_fscore_support
    import matplotlib.pyplot as plt
    
    y = np.random.randint(low=0, high=10, size=100)
    y_p = np.random.randint(low=0, high=10, size=100)
    
    def plot_classification_report(y_tru, y_prd, figsize=(10, 10), ax=None):
    
        plt.figure(figsize=figsize)
    
        xticks = ['precision', 'recall', 'f1-score', 'support']
        yticks = list(np.unique(y_tru))
        yticks += ['avg']
    
        rep = np.array(precision_recall_fscore_support(y_tru, y_prd)).T
        avg = np.mean(rep, axis=0)
        avg[-1] = np.sum(rep[:, -1])
        rep = np.insert(rep, rep.shape[0], avg, axis=0)
    
        sns.heatmap(rep,
                    annot=True, 
                    cbar=False, 
                    xticklabels=xticks, 
                    yticklabels=yticks,
                    ax=ax)
    
    plot_classification_report(y, y_p)
    

    This is what the plot looks like: [figure]
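
Note that the extra 'avg' row in the function above is an unweighted mean over classes (`np.mean`), whereas the `avg / total` row in the question's report is weighted by support. A small sketch of the difference, using the rounded per-class values from the question; the function and variable names are illustrative only.

```python
# (precision, recall, f1-score, support) per class, copied from the question's report.
rows = [
    (0.62, 1.00, 0.76, 66),
    (0.93, 0.93, 0.93, 40),
    (0.59, 0.97, 0.73, 67),
    (0.47, 0.92, 0.62, 272),
    (1.00, 0.16, 0.28, 413),
]


def macro_average(rows, col):
    """Unweighted mean of one metric column over all classes."""
    return sum(r[col] for r in rows) / len(rows)


def weighted_average(rows, col):
    """Mean of one metric column, weighting each class by its support."""
    total = sum(r[3] for r in rows)
    return sum(r[col] * r[3] for r in rows) / total


for col, name in enumerate(['precision', 'recall', 'f1-score']):
    print('{0}: macro={1:.3f} weighted={2:.3f}'.format(
        name, macro_average(rows, col), weighted_average(rows, col)))
```

For precision, the weighted value is about 0.77, matching the report, while the unweighted mean is about 0.72. With classes as imbalanced as these, the choice of average changes what the extra row tells you.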

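    【Comments】:

      【Solution 5】:

      My solution is to use the Python package Yellowbrick. In a nutshell, Yellowbrick combines scikit-learn with matplotlib to produce visualizations for your models. In a few lines you can do what was suggested above. http://www.scikit-yb.org/en/latest/api/classifier/classification_report.html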

      from sklearn.naive_bayes import GaussianNB
      from yellowbrick.classifier import ClassificationReport
      
      # Instantiate the classification model and visualizer
      bayes = GaussianNB()
      visualizer = ClassificationReport(bayes, classes=classes, support=True)
      
      visualizer.fit(X_train, y_train)  # Fit the visualizer and the model
      visualizer.score(X_test, y_test)  # Evaluate the model on the test data
      visualizer.show()             # Draw/show the data
      

      【Comments】:

        【Solution 6】:

        Here you can get the same plot as Franck Dernoncourt's, but with much shorter code (it fits in a single function).

        import matplotlib.pyplot as plt
        import numpy as np
        import itertools
        
        
        def plot_classification_report(classificationReport,
                                       title='Classification report',
                                       cmap='RdBu'):
        
            classificationReport = classificationReport.replace('\n\n', '\n')
            classificationReport = classificationReport.replace(' / ', '/')
            lines = classificationReport.split('\n')
        
            classes, plotMat, support, class_names = [], [], [], []
            for line in lines[1:]:  # if you don't want avg/total result, then change [1:] into [1:-1]
                t = line.strip().split()
                if len(t) < 2:
                    continue
                classes.append(t[0])
                v = [float(x) for x in t[1: len(t) - 1]]
                support.append(int(t[-1]))
                class_names.append(t[0])
                plotMat.append(v)
        
            plotMat = np.array(plotMat)
            xticklabels = ['Precision', 'Recall', 'F1-score']
            yticklabels = ['{0} ({1})'.format(class_names[idx], sup)
                           for idx, sup in enumerate(support)]
        
            plt.imshow(plotMat, interpolation='nearest', cmap=cmap, aspect='auto')
            plt.title(title)
            plt.colorbar()
            plt.xticks(np.arange(3), xticklabels, rotation=45)
            plt.yticks(np.arange(len(classes)), yticklabels)
        
            upper_thresh = plotMat.min() + (plotMat.max() - plotMat.min()) / 10 * 8
            lower_thresh = plotMat.min() + (plotMat.max() - plotMat.min()) / 10 * 2
            for i, j in itertools.product(range(plotMat.shape[0]), range(plotMat.shape[1])):
                plt.text(j, i, format(plotMat[i, j], '.2f'),
                         horizontalalignment="center",
                         color="white" if (plotMat[i, j] > upper_thresh or plotMat[i, j] < lower_thresh) else "black")
        
            # Rows are classes (y-axis) and columns are metrics (x-axis).
            plt.ylabel('Classes')
            plt.xlabel('Metrics')
            plt.tight_layout()
        
        
        def main():
        
            sampleClassificationReport = """             precision    recall  f1-score   support
        
                  Acacia       0.62      1.00      0.76        66
                  Blossom       0.93      0.93      0.93        40
                  Camellia       0.59      0.97      0.73        67
                  Daisy       0.47      0.92      0.62       272
                  Echium       1.00      0.16      0.28       413
        
                avg / total       0.77      0.57      0.49       858"""
        
            plot_classification_report(sampleClassificationReport)
            plt.show()
            plt.close()
        
        
        if __name__ == '__main__':
            main()
        

        【Comments】:

          【Solution 7】:

          If you just want to plot the classification report as a bar chart in a Jupyter notebook, you can do the following.

          # Assuming that classification_report, y_test and predictions are in scope...
          import pandas as pd
          
          # Build a DataFrame from the classification_report output_dict.
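          report_data = []
          for label, metrics in classification_report(y_test, predictions, output_dict=True).items():
              # Skip scalar entries such as 'accuracy', which are not per-label dicts.
              if not isinstance(metrics, dict):
                  continue
              metrics['label'] = label
              report_data.append(metrics)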
          
          report_df = pd.DataFrame(
              report_data, 
              columns=['label', 'precision', 'recall', 'f1-score', 'support']
          )
          
          # Plot as a bar chart.
          report_df.plot(y=['precision', 'recall', 'f1-score'], x='label', kind='bar')
          

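          One problem with this visualization is that imbalanced classes are not obvious, yet they matter when interpreting the results. One way to show this is to add a version of the label that includes the number of samples (i.e. the support):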

          # Add a column to the DataFrame.
          report_df['labelsupport'] = [f'{label} (n={support})' 
                                       for label, support in zip(report_df.label, report_df.support)]
          
          # Plot the chart the same way, but use `labelsupport` as the x-axis.
          report_df.plot(y=['precision', 'recall', 'f1-score'], x='labelsupport', kind='bar')
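
The same label-with-support idea can be sketched without pandas, which may help when checking what `output_dict=True` returns. The dictionary below is a hand-written stand-in with the shape that `classification_report(..., output_dict=True)` produces in recent scikit-learn versions (per-label dicts plus a scalar `accuracy` entry). Its values are made up, and `rows_with_support_labels` is a hypothetical helper.

```python
def rows_with_support_labels(report_dict):
    """Flatten a classification-report dict into a list of row dicts.

    Scalar entries (such as 'accuracy') are skipped because they are not
    per-label metric dicts. Each row gets a 'labelsupport' text like 'dog (n=3)'.
    """
    rows = []
    for label, metrics in report_dict.items():
        if not isinstance(metrics, dict):
            continue
        row = dict(metrics)  # copy so the input dict is left untouched
        row['label'] = label
        row['labelsupport'] = '{0} (n={1})'.format(label, int(metrics['support']))
        rows.append(row)
    return rows


# Made-up values, only meant to mirror the structure of output_dict=True.
fake_report = {
    'cat': {'precision': 0.50, 'recall': 1.00, 'f1-score': 0.67, 'support': 1},
    'dog': {'precision': 1.00, 'recall': 0.67, 'f1-score': 0.80, 'support': 3},
    'accuracy': 0.75,
    'macro avg': {'precision': 0.75, 'recall': 0.83, 'f1-score': 0.73, 'support': 4},
    'weighted avg': {'precision': 0.88, 'recall': 0.75, 'f1-score': 0.77, 'support': 4},
}

report_rows = rows_with_support_labels(fake_report)
for row in report_rows:
    print(row['labelsupport'], row['precision'], row['recall'], row['f1-score'])
```

A list of row dicts like this can be passed straight to `pd.DataFrame(...)` if a DataFrame is wanted for plotting.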
          

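          【Comments】:

            【Solution 8】:

            No string processing + sns.heatmap

            The following solution uses the output_dict=True option of classification_report to get a dictionary, and then uses seaborn to draw a heatmap of a DataFrame created from that dictionary.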


            import numpy as np
            import seaborn as sns
            from sklearn.metrics import classification_report
            import pandas as pd
            

            Generate data. Classes: A, B, C, D, E, F, G, H, I, J

            true = np.random.randint(0, 10, size=100)
            pred = np.random.randint(0, 10, size=100)
            labels = np.arange(10)
            # One name per label: 10 labels need 10 target names.
            target_names = list("ABCDEFGHIJ")
            

            Call classification_report with output_dict=True

            clf_report = classification_report(true,
                                               pred,
                                               labels=labels,
                                               target_names=target_names,
                                               output_dict=True)
            

            Create a DataFrame from the dictionary and plot a heatmap of it.

            # .iloc[:-1, :] to exclude support
            sns.heatmap(pd.DataFrame(clf_report).iloc[:-1, :].T, annot=True)
            

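            【Comments】:

              【Solution 9】:

              The answers by Franck Dernoncourt and Bin were very useful to me, but I ran into two problems.

              First, when I tried to use the function with classes such as "No hit", or with names containing spaces, the plot failed.
              The other problem was using this function with Matplotlib 3.* and scikit-learn 0.22.*. So I made some small changes: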

              import matplotlib.pyplot as plt
              import numpy as np
              
              def show_values(pc, fmt="%.2f", **kw):
                  '''
                  Heatmap with text in each cell with matplotlib's pyplot
                  Source: https://stackoverflow.com/a/25074150/395857 
                  By HYRY
                  '''
                  pc.update_scalarmappable()
                  ax = pc.axes
                  #ax = pc.axes# FOR LATEST MATPLOTLIB
                  #Use zip BELOW IN PYTHON 3
                  for p, color, value in zip(pc.get_paths(), pc.get_facecolors(), pc.get_array()):
                      x, y = p.vertices[:-2, :].mean(0)
                      if np.all(color[:3] > 0.5):
                          color = (0.0, 0.0, 0.0)
                      else:
                          color = (1.0, 1.0, 1.0)
                      ax.text(x, y, fmt % value, ha="center", va="center", color=color, **kw)
              
              
              def cm2inch(*tupl):
                  '''
                  Specify figure size in centimeter in matplotlib
                  Source: https://stackoverflow.com/a/22787457/395857
                  By gns-ank
                  '''
                  inch = 2.54
                  if type(tupl[0]) == tuple:
                      return tuple(i/inch for i in tupl[0])
                  else:
                      return tuple(i/inch for i in tupl)
              
              
              def heatmap(AUC, title, xlabel, ylabel, xticklabels, yticklabels, figure_width=40, figure_height=20, correct_orientation=False, cmap='RdBu'):
                  '''
                  Inspired by:
                  - https://stackoverflow.com/a/16124677/395857 
                  - https://stackoverflow.com/a/25074150/395857
                  '''
              
                  # Plot it out
                  fig, ax = plt.subplots()    
                  #c = ax.pcolor(AUC, edgecolors='k', linestyle= 'dashed', linewidths=0.2, cmap='RdBu', vmin=0.0, vmax=1.0)
                  c = ax.pcolor(AUC, edgecolors='k', linestyle= 'dashed', linewidths=0.2, cmap=cmap, vmin=0.0, vmax=1.0)
              
                  # put the major ticks at the middle of each cell
                  ax.set_yticks(np.arange(AUC.shape[0]) + 0.5, minor=False)
                  ax.set_xticks(np.arange(AUC.shape[1]) + 0.5, minor=False)
              
                  # set tick labels
                  #ax.set_xticklabels(np.arange(1,AUC.shape[1]+1), minor=False)
                  ax.set_xticklabels(xticklabels, minor=False)
                  ax.set_yticklabels(yticklabels, minor=False)
              
                  # set title and x/y labels
                  plt.title(title, y=1.25)
                  plt.xlabel(xlabel)
                  plt.ylabel(ylabel)      
              
                  # Remove last blank column
                  plt.xlim( (0, AUC.shape[1]) )
              
                  # Turn off all the ticks
                  ax = plt.gca()    
                  for t in ax.xaxis.get_major_ticks():
                      t.tick1line.set_visible(False)
                      t.tick2line.set_visible(False)
                  for t in ax.yaxis.get_major_ticks():
                      t.tick1line.set_visible(False)
                      t.tick2line.set_visible(False)
              
                  # Add color bar
                  plt.colorbar(c)
              
                  # Add text in each cell 
                  show_values(c)
              
                  # Proper orientation (origin at the top left instead of bottom left)
                  if correct_orientation:
                      ax.invert_yaxis()
                      ax.xaxis.tick_top()       
              
                  # resize 
                  fig = plt.gcf()
                  #fig.set_size_inches(cm2inch(40, 20))
                  #fig.set_size_inches(cm2inch(40*4, 20*4))
                  fig.set_size_inches(cm2inch(figure_width, figure_height))
              
              
              
              def plot_classification_report(classification_report, number_of_classes=2, title='Classification report ', cmap='RdYlGn'):
                  '''
                  Plot scikit-learn classification report.
                  Extension based on https://stackoverflow.com/a/31689645/395857 
                  '''
                  lines = classification_report.split('\n')
                  
                  #drop initial lines
                  lines = lines[2:]
              
                  classes = []
                  plotMat = []
                  support = []
                  class_names = []
                  for line in lines[: number_of_classes]:
                      t = list(filter(None, line.strip().split('  ')))
                      if len(t) < 4: continue
                      classes.append(t[0])
                      v = [float(x) for x in t[1: len(t) - 1]]
                      support.append(int(t[-1]))
                      class_names.append(t[0])
                      plotMat.append(v)
              
              
                  xlabel = 'Metrics'
                  ylabel = 'Classes'
                  xticklabels = ['Precision', 'Recall', 'F1-score']
                  yticklabels = ['{0} ({1})'.format(class_names[idx], sup) for idx, sup  in enumerate(support)]
                  figure_width = 10
                  figure_height = len(class_names) + 3
                  correct_orientation = True
                  heatmap(np.array(plotMat), title, xlabel, ylabel, xticklabels, yticklabels, figure_width, figure_height, correct_orientation, cmap=cmap)
                  plt.show()
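
The change that makes class names with spaces work is splitting each row on runs of two or more spaces instead of on any whitespace. A minimal sketch of that step in isolation, using `re.split` as an alternative to the `split('  ')` plus `filter` combination above. The sample row is made up, `split_report_row` is a hypothetical name, and the sketch assumes that names never contain two consecutive spaces.

```python
import re


def split_report_row(line):
    """Split one report row into [name, precision, recall, f1-score, support].

    Columns are separated by at least two spaces, so a single space inside
    a class name is kept.
    """
    return re.split(r'\s{2,}', line.strip())


row = '      No hit       0.62      1.00      0.76        66'
print(row.split())            # plain split() breaks the name into 'No' and 'hit'
print(split_report_row(row))  # the name stays in one piece
```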
              
              
              

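              【Comments】:

                【Solution 10】:
                This worked for me, pieced together from the top answers above. I cannot comment, but thanks to everyone in this thread, it helped a lot!
                import matplotlib.pyplot as plt
                import numpy as np

                def plot_classification_report(cr, title='Classification report ', with_avg_total=False, cmap=plt.cm.Blues):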
                    lines = cr.split('\n')
                    classes = []
                    plotMat = []
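                    for line in lines[2:]:
                        t = line.split()
                        if len(t) == 0:
                            # A blank line ends the per-class rows; summary rows follow it.
                            break
                        classes.append(t[0])
                        v = [float(x) for x in t[1: len(t) - 1]]
                        plotMat.append(v)

                    if with_avg_total:
                        # The last non-empty line is the final summary row (e.g. 'weighted avg ...').
                        aveTotal = [l for l in lines if l.strip()][-1].split()
                        classes.append('avg/total')
                        # The last four tokens are precision, recall, f1-score and support.
                        vAveTotal = [float(x) for x in aveTotal[-4:-1]]
                        plotMat.append(vAveTotal)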
                
                    import seaborn as sns

                    plt.figure(figsize=(12, 48))
                    # plt.imshow(plotMat, interpolation='nearest', cmap=cmap) also works, but the scale and the colors are poor for many classes (~200)
                    # Pass the tick labels to seaborn directly: heatmap() sets its own ticks and would overwrite labels set beforehand.
                    sns.heatmap(plotMat, annot=True,
                                xticklabels=['precision', 'recall', 'f1-score'],
                                yticklabels=classes)
                    plt.title(title)
                    plt.xticks(rotation=45)
                    plt.yticks(rotation=0)
                    plt.ylabel('Classes')
                    plt.xlabel('Measures')
                    plt.tight_layout()
                
                After this, make sure the class labels do not contain any spaces, because each row is split on whitespace
                reportstr = classification_report(true_classes, y_pred,target_names=class_labels_no_spaces)
                
                plot_classification_report(reportstr)
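
One way to produce such space-free labels is to replace whitespace before calling classification_report. A small sketch; the `class_labels` values, the helper name `without_spaces` and the underscore replacement are assumptions for illustration.

```python
import re


def without_spaces(labels, replacement='_'):
    """Return copies of the labels with each run of whitespace replaced."""
    return [re.sub(r'\s+', replacement, str(label).strip()) for label in labels]


class_labels = ['No hit', 'Acacia', 'Echium vulgare ']  # hypothetical label names
class_labels_no_spaces = without_spaces(class_labels)
print(class_labels_no_spaces)
```

The resulting list can then be passed as `target_names`, so that every row of the report splits into the same number of tokens.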
                

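                【Comments】:

                  【Solution 11】:

                  For those asking how to make this work with the latest version of classification_report(y_test, y_pred): you have to change the -2 to -4 in the plot_classification_report() function from the accepted answer code in this thread.

                  I could not post this as a comment on the answer because my account does not have enough reputation.

                  You need to change for line in lines[2 : (len(lines) - 2)]: to for line in lines[2 : (len(lines) - 4)]:

                  Or copy this edited version: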

                  import matplotlib.pyplot as plt
                  import numpy as np
                  
                  def show_values(pc, fmt="%.2f", **kw):
                      '''
                      Heatmap with text in each cell with matplotlib's pyplot
                      Source: https://stackoverflow.com/a/25074150/395857 
                      By HYRY
                      '''
                      pc.update_scalarmappable()
                      ax = pc.axes
                      #ax = pc.axes# FOR LATEST MATPLOTLIB
                      #Use zip BELOW IN PYTHON 3
                      for p, color, value in zip(pc.get_paths(), pc.get_facecolors(), pc.get_array()):
                          x, y = p.vertices[:-2, :].mean(0)
                          if np.all(color[:3] > 0.5):
                              color = (0.0, 0.0, 0.0)
                          else:
                              color = (1.0, 1.0, 1.0)
                          ax.text(x, y, fmt % value, ha="center", va="center", color=color, **kw)
                  
                  
                  def cm2inch(*tupl):
                      '''
                      Specify figure size in centimeter in matplotlib
                      Source: https://stackoverflow.com/a/22787457/395857
                      By gns-ank
                      '''
                      inch = 2.54
                      if type(tupl[0]) == tuple:
                          return tuple(i/inch for i in tupl[0])
                      else:
                          return tuple(i/inch for i in tupl)
                  
                  
                  def heatmap(AUC, title, xlabel, ylabel, xticklabels, yticklabels, figure_width=40, figure_height=20, correct_orientation=False, cmap='RdBu'):
                      '''
                      Inspired by:
                      - https://stackoverflow.com/a/16124677/395857 
                      - https://stackoverflow.com/a/25074150/395857
                      '''
                  
                      # Plot it out
                      fig, ax = plt.subplots()    
                      #c = ax.pcolor(AUC, edgecolors='k', linestyle= 'dashed', linewidths=0.2, cmap='RdBu', vmin=0.0, vmax=1.0)
                      c = ax.pcolor(AUC, edgecolors='k', linestyle= 'dashed', linewidths=0.2, cmap=cmap)
                  
                      # put the major ticks at the middle of each cell
                      ax.set_yticks(np.arange(AUC.shape[0]) + 0.5, minor=False)
                      ax.set_xticks(np.arange(AUC.shape[1]) + 0.5, minor=False)
                  
                      # set tick labels
                      #ax.set_xticklabels(np.arange(1,AUC.shape[1]+1), minor=False)
                      ax.set_xticklabels(xticklabels, minor=False)
                      ax.set_yticklabels(yticklabels, minor=False)
                  
                      # set title and x/y labels
                      plt.title(title)
                      plt.xlabel(xlabel)
                      plt.ylabel(ylabel)      
                  
                      # Remove last blank column
                      plt.xlim( (0, AUC.shape[1]) )
                  
                      # Turn off all the ticks
                      ax = plt.gca()    
                      for t in ax.xaxis.get_major_ticks():
                          t.tick1line.set_visible(False)
                          t.tick2line.set_visible(False)
                      for t in ax.yaxis.get_major_ticks():
                          t.tick1line.set_visible(False)
                          t.tick2line.set_visible(False)
                  
                      # Add color bar
                      plt.colorbar(c)
                  
                      # Add text in each cell 
                      show_values(c)
                  
                      # Proper orientation (origin at the top left instead of bottom left)
                      if correct_orientation:
                          ax.invert_yaxis()
                          ax.xaxis.tick_top()       
                  
                      # resize 
                      fig = plt.gcf()
                      #fig.set_size_inches(cm2inch(40, 20))
                      #fig.set_size_inches(cm2inch(40*4, 20*4))
                      fig.set_size_inches(cm2inch(figure_width, figure_height))
                  
                  
                  
                  def plot_classification_report(classification_report, title='Classification report ', cmap='RdBu'):
                      '''
                      Plot scikit-learn classification report.
                      Extension based on https://stackoverflow.com/a/31689645/395857 
                      '''
                      lines = classification_report.split('\n')
                  
                      classes = []
                      plotMat = []
                      support = []
                      class_names = []
                  
                      for line in lines[2 : (len(lines) - 4)]:
                          t = line.strip().split()
                          if len(t) < 2: continue
                          classes.append(t[0])
                          v = [float(x) for x in t[1: len(t) - 1]]
                          support.append(int(t[-1]))
                          class_names.append(t[0])
                          print(v)
                          plotMat.append(v)
                  
                      print('plotMat: {0}'.format(plotMat))
                      print('support: {0}'.format(support))
                  
                      xlabel = 'Metrics'
                      ylabel = 'Classes'
                      xticklabels = ['Precision', 'Recall', 'F1-score']
                      yticklabels = ['{0} ({1})'.format(class_names[idx], sup) for idx, sup  in enumerate(support)]
                      figure_width = 25
                      figure_height = len(class_names) + 7
                      correct_orientation = False
                      heatmap(np.array(plotMat), title, xlabel, ylabel, xticklabels, yticklabels, figure_width, figure_height, correct_orientation, cmap=cmap)
                  
                  
                  def main():
                      # OLD 
                      # sampleClassificationReport = """             precision    recall  f1-score   support
                      # 
                      #       Acacia       0.62      1.00      0.76        66
                      #       Blossom       0.93      0.93      0.93        40
                      #       Camellia       0.59      0.97      0.73        67
                      #       Daisy       0.47      0.92      0.62       272
                      #       Echium       1.00      0.16      0.28       413
                      # 
                      #     avg / total       0.77      0.57      0.49       858"""
                  
                      # NEW
                      sampleClassificationReport = """              precision    recall  f1-score   support
                  
                             1       1.00      0.33      0.50         9
                             2       0.50      1.00      0.67         9
                             3       0.86      0.67      0.75         9
                             4       0.90      1.00      0.95         9
                             5       0.67      0.89      0.76         9
                             6       1.00      1.00      1.00         9
                             7       1.00      1.00      1.00         9
                             8       0.90      1.00      0.95         9
                             9       0.86      0.67      0.75         9
                            10       1.00      0.78      0.88         9
                            11       1.00      0.89      0.94         9
                            12       0.90      1.00      0.95         9
                            13       1.00      0.56      0.71         9
                            14       1.00      1.00      1.00         9
                            15       0.60      0.67      0.63         9
                            16       1.00      0.56      0.71         9
                            17       0.75      0.67      0.71         9
                            18       0.80      0.89      0.84         9
                            19       1.00      1.00      1.00         9
                            20       1.00      0.78      0.88         9
                            21       1.00      1.00      1.00         9
                            22       1.00      1.00      1.00         9
                            23       0.27      0.44      0.33         9
                            24       0.60      1.00      0.75         9
                            25       0.56      1.00      0.72         9
                            26       0.18      0.22      0.20         9
                            27       0.82      1.00      0.90         9
                            28       0.00      0.00      0.00         9
                            29       0.82      1.00      0.90         9
                            30       0.62      0.89      0.73         9
                            31       1.00      0.44      0.62         9
                            32       1.00      0.78      0.88         9
                            33       0.86      0.67      0.75         9
                            34       0.64      1.00      0.78         9
                            35       1.00      0.33      0.50         9
                            36       1.00      0.89      0.94         9
                            37       0.50      0.44      0.47         9
                            38       0.69      1.00      0.82         9
                            39       1.00      0.78      0.88         9
                            40       0.67      0.44      0.53         9
                  
                      accuracy                           0.77       360
                     macro avg       0.80      0.77      0.76       360
                  weighted avg       0.80      0.77      0.76       360
                      """
                      plot_classification_report(sampleClassificationReport)
                      plt.savefig('test_plot_classif_report.png', dpi=200, format='png', bbox_inches='tight')
                      plt.close()
                  
                  if __name__ == "__main__":
                      main()
                      #cProfile.run('main()') # if you want to do some profiling
                  

                  【Comments】:
