【问题标题】:How to add group labels for bar charts in matplotlib如何在 matplotlib 中为条形图添加组标签
【发布时间】:2013-10-11 15:27:56
【问题描述】:

我想绘制以下形式的数据,使用 matplotlib 条形图:

data = {'Room A':
           {'Shelf 1':
               {'Milk': 10,
                'Water': 20},
            'Shelf 2':
               {'Sugar': 5,
                'Honey': 6}
           },
        'Room B':
           {'Shelf 1':
               {'Wheat': 4,
                'Corn': 7},
            'Shelf 2':
               {'Chicken': 2,
                'Cow': 1}
           }
       }

条形图应该是这样的

从 x 轴上的标签应该可以看到条形组。有没有办法用 matplotlib 做到这一点?

【问题讨论】:

    标签: python matplotlib bar-chart


    【解决方案1】:

    由于我在 matplotlib 中找不到内置解决方案,因此我编写了自己的代码:

    #!/usr/bin/env python
    
    from matplotlib import pyplot as plt
    
    def mk_groups(data):
        try:
            newdata = data.items()
        except:
            return
    
        thisgroup = []
        groups = []
        for key, value in newdata:
            newgroups = mk_groups(value)
            if newgroups is None:
                thisgroup.append((key, value))
            else:
                thisgroup.append((key, len(newgroups[-1])))
                if groups:
                    groups = [g + n for n, g in zip(newgroups, groups)]
                else:
                    groups = newgroups
        return [thisgroup] + groups
    
    def add_line(ax, xpos, ypos):
        line = plt.Line2D([xpos, xpos], [ypos + .1, ypos],
                          transform=ax.transAxes, color='black')
        line.set_clip_on(False)
        ax.add_line(line)
    
    def label_group_bar(ax, data):
        groups = mk_groups(data)
        xy = groups.pop()
        x, y = zip(*xy)
        ly = len(y)
        xticks = range(1, ly + 1)
    
        ax.bar(xticks, y, align='center')
        ax.set_xticks(xticks)
        ax.set_xticklabels(x)
        ax.set_xlim(.5, ly + .5)
        ax.yaxis.grid(True)
    
        scale = 1. / ly
        for pos in xrange(ly + 1):  # change xrange to range for python3
            add_line(ax, pos * scale, -.1)
        ypos = -.2
        while groups:
            group = groups.pop()
            pos = 0
            for label, rpos in group:
                lxpos = (pos + .5 * rpos) * scale
                ax.text(lxpos, ypos, label, ha='center', transform=ax.transAxes)
                add_line(ax, pos * scale, ypos)
                pos += rpos
            add_line(ax, pos * scale, ypos)
            ypos -= .1
    
    if __name__ == '__main__':
        data = {'Room A':
                   {'Shelf 1':
                       {'Milk': 10,
                        'Water': 20},
                    'Shelf 2':
                       {'Sugar': 5,
                        'Honey': 6}
                   },
                'Room B':
                   {'Shelf 1':
                       {'Wheat': 4,
                        'Corn': 7},
                    'Shelf 2':
                       {'Chicken': 2,
                        'Cow': 1}
                   }
               }
        fig = plt.figure()
        ax = fig.add_subplot(1,1,1)
        label_group_bar(ax, data)
        fig.subplots_adjust(bottom=0.3)
        fig.savefig('label_group_bar_example.png')
    

    mk_groups 函数接受一个字典(或任何带有 items() 方法的东西,如collections.OrderedDict)并将其转换为数据格式,然后用于创建图表。它基本上是一个列表形式:

    [ [(label, bars_to_span), ...], ..., [(tick_label, bar_value), ...] ]
    

    add_line 函数在子图中的指定位置(在坐标轴坐标中)创建一条垂直线。

    label_group_bar 函数采用字典并在子图中创建条形图,其下方带有标签。该示例的结果看起来像like this

    仍然非常感谢更简单或更好的解决方案和建议。

    【讨论】:

    • 如果您使用的是 Python 3,xrange 已重命名为 range。所以使用 range 而不是 xrange。
    【解决方案2】:

    我一直在寻找这个解决方案。 我对其进行了一些修改以使用熊猫数据表。 只有公平分享。

    import pandas as pd
    import numpy as np
    from matplotlib import pyplot as plt
    from itertools import groupby
    
    def test_table():
        data_table = pd.DataFrame({'Room':['Room A']*4 + ['Room B']*4,
                                   'Shelf':(['Shelf 1']*2 + ['Shelf 2']*2)*2,
                                   'Staple':['Milk','Water','Sugar','Honey','Wheat','Corn','Chicken','Cow'],
                                   'Quantity':[10,20,5,6,4,7,2,1],
                                   'Ordered':np.random.randint(0,10,8)
                                   })
        return data_table
    
    def add_line(ax, xpos, ypos):
        line = plt.Line2D([xpos, xpos], [ypos + .1, ypos],
                          transform=ax.transAxes, color='black')
        line.set_clip_on(False)
        ax.add_line(line)
    
    def label_len(my_index,level):
        labels = my_index.get_level_values(level)
        return [(k, sum(1 for i in g)) for k,g in groupby(labels)]
        
    def label_group_bar_table(ax, df):
        ypos = -.1
        scale = 1./df.index.size
        for level in range(df.index.nlevels)[::-1]:
            pos = 0
            for label, rpos in label_len(df.index,level):
                lxpos = (pos + .5 * rpos)*scale
                ax.text(lxpos, ypos, label, ha='center', transform=ax.transAxes)
                add_line(ax, pos*scale, ypos)
                pos += rpos
            add_line(ax, pos*scale , ypos)
            ypos -= .1
    
    df = test_table().groupby(['Room','Shelf','Staple']).sum()
    fig = plt.figure()
    ax = fig.add_subplot(111)
    df.plot(kind='bar',stacked=True,ax=fig.gca())
    #Below 3 lines remove default labels
    labels = ['' for item in ax.get_xticklabels()]
    ax.set_xticklabels(labels)
    ax.set_xlabel('')
    label_group_bar_table(ax, df)
    fig.subplots_adjust(bottom=.1*df.index.nlevels)
    plt.show()
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2015-05-09
      相关资源
      最近更新 更多