【问题标题】:matplotlib does not show legend in scatter plotmatplotlib 在散点图中不显示图例
【发布时间】:2016-12-29 17:05:34
【问题描述】:

我正在尝试解决一个聚类问题,我需要为我的聚类绘制散点图。

%matplotlib inline
import matplotlib.pyplot as plt
df = pd.merge(dataframe,actual_cluster)
plt.scatter(df['x'], df['y'], c=df['cluster'])
plt.legend()
plt.show()

df['cluster'] 是实际的簇号。所以我希望它成为我的颜色代码。

它向我展示了一个情节,但它没有向我展示传说。它也不会给我错误。

我做错了吗?

【问题讨论】:

    标签: python matplotlib plot cluster-analysis


    【解决方案1】:

    编辑:

    生成一些随机数据:

    from scipy.cluster.vq import kmeans2
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    
    n_clusters = 10
    df = pd.DataFrame({'x':np.random.randn(1000), 'y':np.random.randn(1000)})
    _, df['cluster'] = kmeans2(df, n_clusters)
    

    更新

    • 使用seaborn.relplotkind='scatter' 或使用seaborn.scatterplot
      • 指定hue='cluster'
    # figure level plot
    sns.relplot(data=df, x='x', y='y', hue='cluster', palette='tab10', kind='scatter')
    

    # axes level plot
    fig, axes = plt.subplots(figsize=(6, 6))
    sns.scatterplot(data=df, x='x', y='y', hue='cluster', palette='tab10', ax=axes)
    axes.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    

    原答案

    绘图(matplotlib v3.3.4):

    fig, ax = plt.subplots(figsize=(8, 6))
    cmap = plt.cm.get_cmap('jet')
    for i, cluster in df.groupby('cluster'):
        _ = ax.scatter(cluster['x'], cluster['y'], color=cmap(i/n_clusters), label=i, ec='k')
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    

    结果:

    解释:

    不要过多介绍 matplotlib 内部的细节,一次绘制一个集群可以解决问题。 更具体地说,ax.scatter() 返回一个 PathCollection 对象,我们在此处明确将其丢弃,但 似乎 在内部传递给某种图例处理程序。一次绘制只生成一个PathCollection/label 对,而一次绘制一个集群会生成n_clustersPathCollection/label 对。您可以通过调用 ax.get_legend_handles_labels() 来查看这些对象,它会返回如下内容:

    ([<matplotlib.collections.PathCollection at 0x7f60c2ff2ac8>,
      <matplotlib.collections.PathCollection at 0x7f60c2ff9d68>,
      <matplotlib.collections.PathCollection at 0x7f60c2ff9390>,
      <matplotlib.collections.PathCollection at 0x7f60c2f802e8>,
      <matplotlib.collections.PathCollection at 0x7f60c2f809b0>,
      <matplotlib.collections.PathCollection at 0x7f60c2ff9908>,
      <matplotlib.collections.PathCollection at 0x7f60c2f85668>,
      <matplotlib.collections.PathCollection at 0x7f60c2f8cc88>,
      <matplotlib.collections.PathCollection at 0x7f60c2f8c748>,
      <matplotlib.collections.PathCollection at 0x7f60c2f92d30>],
     ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'])
    

    所以实际上ax.legend() 等价于ax.legend(*ax.get_legend_handles_labels())

    注意事项:

    1. 如果使用 Python 2,请确保 i/n_clustersfloat

    2. 省略 fig, ax = plt.subplots() 并改用 plt.&lt;method&gt; ax.&lt;method&gt; 工作正常,但我总是更喜欢明确 指定我正在使用的 Axes 对象,而不是隐式使用 “当前坐标区” (plt.gca())。


    旧的简单解决方案

    如果您对颜色条(而不是离散值标签)没问题,您可以使用 Pandas 内置的 Matplotlib 功能:

    df.plot.scatter('x', 'y', c='cluster', cmap='jet')
    

    【讨论】:

      【解决方案2】:

      这是一个困扰我很久的问题。现在,我想提供另一个简单的解决方案。我们不必编写任何循环!!!

      def vis(ax, df, label, title="visualization"):
          points = ax.scatter(df[:, 0], df[:, 1], c=label, label=label, alpha=0.7)
          ax.set_title(title)
          ax.legend(*points.legend_elements(), title="Classes")
      

      【讨论】:

      • 简单而完美的解决方案。