【问题标题】:matplotlib scatterplot with legendmatplotlib 散点图与图例
【发布时间】:2017-06-22 17:47:24
【问题描述】:

我有兴趣在我的散点图中绘制一个图例。我当前的代码是这样的

x=[1,2,3,4]
y=[5,6,7,8]
classes = [2,4,4,2]
plt.scatter(x, y, c=classes, label=classes)
plt.legend()

问题是在创建绘图时,图例显示为数组,而不是显示唯一标签及其类。

我知道这是之前在one 等线程中讨论过的一个问题,但是我觉得我的问题更简单,并且那里的解决方案不适合它。此外,在该示例中,该人正在指定颜色,但在我的情况下,我事先知道我需要多少种颜色。此外,在this 示例中,用户正在创建多个散点图,每个散点图具有唯一的颜色。同样,这不是我想要的。我的目标是使用 x,y 数组和标签简单地创建绘图。这可能吗?

谢谢。

【问题讨论】:

    标签: python matplotlib plot


    【解决方案1】:

    实际上,这两个链接问题都提供了一种实现预期结果的方法。

    最简单的方法是创建与存在的唯一类一样多的散点图,并为每个类指定一个颜色和图例条目。

    import matplotlib.pyplot as plt
    
    x=[1,2,3,4]
    y=[5,6,7,8]
    classes = [2,4,4,2]
    unique = list(set(classes))
    colors = [plt.cm.jet(float(i)/max(unique)) for i in unique]
    for i, u in enumerate(unique):
        xi = [x[j] for j  in range(len(x)) if classes[j] == u]
        yi = [y[j] for j  in range(len(x)) if classes[j] == u]
        plt.scatter(xi, yi, c=colors[i], label=str(u))
    plt.legend()
    
    plt.show()
    

    如果类是字符串标签,解决方案看起来会略有不同,因为您需要从它们的索引中获取颜色,而不是使用类本身。

    import numpy as np
    import matplotlib.pyplot as plt
    
    x=[1,2,3,4]
    y=[5,6,7,8]
    classes = ['X','Y','Z','X']
    unique = np.unique(classes)
    colors = [plt.cm.jet(i/float(len(unique)-1)) for i in range(len(unique))]
    for i, u in enumerate(unique):
        xi = [x[j] for j  in range(len(x)) if classes[j] == u]
        yi = [y[j] for j  in range(len(x)) if classes[j] == u]
        plt.scatter(xi, yi, c=colors[i], label=str(u))
    plt.legend()
    
    plt.show()
    

    【讨论】:

    • 您的回答帮助我创建了我想要的情节。谢谢。
    • 只有我一个人对没有内置方法感到惊讶吗?我觉得绘制不同类别​​的点分布是一项非常常见的任务。如果您知道原因,请赐教。
    • @Johannes 您到底想要什么“内置”?这是剧情还是传说?
    • 为散点图构建图例,简单地显示哪种颜色对应于哪个标签。我发现我必须遍历标签感到惊讶。一个(imo)非常常见的用例是机器学习:您对特征进行了精美的降维处理,并希望将它们绘制为 2D。你得到的是许多属于有限类集的点(-> 你的标签/颜色)。例如,谷歌搜索“t-SNE 图”。
    • @Johannes 您不必遍历类。例如,这里的另一个答案根本没有循环。我不知道使用 matplotlib 创建的所有绘图的比例会落在机器学习的范围内,但 matplotlib 肯定不是专门针对任何领域的;它只是提供了创建数据图的方法。它的核心代码应该比机器学习炒作更早。在 matplotlib 中创建自定义图例非常简单,因此如果您需要重复执行此操作,您可以编写自己的 scatterlegend(sc, values, classes=None) 函数。如果您需要帮助,何不提出一个新问题。
    【解决方案2】:

    在这里手动填写table 可能会很有用。如果你的类是连续的数字,另一个想法是使用colorbar。我将两种方法合二为一。

    import matplotlib.pyplot as plt
    import numpy as np
    
    x=[1,2,3,4,5,6,7]
    y=[1,2,3,4,5,6,7]
    classes = [2,4,4,2,1,3,5]
    cmap = plt.cm.get_cmap("viridis",5)
    plt.scatter(x, y, c=classes, label=classes,cmap=cmap,vmin=0.5,vmax=5.5)
    plt.colorbar()
    unique_classes = list(set(classes))
    plt.table(cellText=[[x] for x in unique_classes], loc='lower right',
              colWidths=[0.2],rowColours=cmap(np.array(unique_classes)-1),
             rowLabels=['label%d'%x for x in unique_classes],
              colLabels=['classes'])
    

    【讨论】:

    • 好主意!只有颜色映射有点hacky。
    • 是的,它需要一些调整,并且仅当您的数据包含区间内的所有数字时才适用。