【问题标题】:Scatterplot in matplotlib with legend and randomized point ordermatplotlib 中带有图例和随机点序的散点图
【发布时间】:2026-01-01 10:00:01
【问题描述】:

我正在尝试在 python/matplotlib 中构建来自多个类的大量数据的散点图。不幸的是,我似乎必须在随机化数据和使用图例标签之间做出选择。有没有一种方法我可以两者兼得(最好不用手动编码标签?)

最小可重现示例:

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
X = np.random.normal(0, 1, [5000, 2])
Y = np.random.normal(0.5, 1, [5000, 2]) 
data = np.concatenate([X,Y])
classes = np.concatenate([np.repeat('X', X.shape[0]),
                          np.repeat('Y', Y.shape[0])])

用随机点绘图:

plot_idx = np.random.permutation(data.shape[0])
colors = pd.factorize(classes)
fig, ax = plt.subplots()
ax.scatter(data[plot_idx, 0], 
           data[plot_idx, 1], 
           c=colors[plot_idx],
           label=classes[plot_idx],
           alpha=0.4)
plt.legend()
plt.show()

这给了我错误的传说。

使用正确的图例绘图:

from matplotlib import cm
unique_classes = np.unique(classes)
colors = cm.Set1(np.linspace(0, 1, len(unique_classes)))
for i, class in enumerate(unique_classes):
    ax.scatter(data[classes == class, 0], 
               data[classes == class, 1],
               c=colors[i],
               label=class,
               alpha=0.4)
plt.legend()
plt.show()

但现在点不是随机的,结果图不代表数据。

我正在寻找可以给我一个结果的东西,就像我在 R 中得到的那样:

library(ggplot2)
X <- matrix(rnorm(10000, 0, 1), ncol=2)
Y <- matrix(rnorm(10000, 0.5, 1), ncol=2)
data <- as.data.frame(rbind(X, Y))
data$classes <- rep(c('X', 'Y'), times=nrow(X))
plot_idx <- sample(nrow(data))

ggplot(data[plot_idx,], aes(x=V1, y=V2, color=classes)) +
  geom_point(alpha=0.4, size=3)

【问题讨论】:

    标签: python matplotlib


    【解决方案1】:

    您需要手动创建图例。不过,这不是什么大问题。您可以遍历标签并为每个标签创建一个图例条目。这里可以使用Line2D,其标记类似于散点图作为句柄。

    import numpy as np
    import matplotlib.pyplot as plt
    import pandas as pd
    X = np.random.normal(0, 1, [5000, 2])
    Y = np.random.normal(0.5, 1, [5000, 2]) 
    data = np.concatenate([X,Y])
    classes = np.concatenate([np.repeat('X', X.shape[0]),
                              np.repeat('Y', Y.shape[0])])
    
    plot_idx = np.random.permutation(data.shape[0])
    colors,labels = pd.factorize(classes)
    
    fig, ax = plt.subplots()
    sc = ax.scatter(data[plot_idx, 0], 
               data[plot_idx, 1], 
               c=colors[plot_idx],
               alpha=0.4)
    
    h = lambda c: plt.Line2D([],[],color=c, ls="",marker="o")
    plt.legend(handles=[h(sc.cmap(sc.norm(i))) for i in range(len(labels))],
               labels=list(labels))
    plt.show()
    

    或者,您可以使用特殊的分散处理程序,如问题Why doesn't the color of the points in a scatter plot match the color of the points in the corresponding legend? 中所示,但这似乎有点过分。

    【讨论】:

      【解决方案2】:

      这有点小技巧,但您可以保存坐标轴范围,通过在绘图范围之外绘制点来设置标签,然后按如下方式重置坐标轴范围:

      plot_idx = np.random.permutation(data.shape[0])
      color_idx, unique_classes = pd.factorize(classes)
      colors = cm.Set1(np.linspace(0, 1, len(unique_classes)))
      fig, ax = plt.subplots()
      ax.scatter(data[plot_idx, 0], 
                 data[plot_idx, 1], 
                 c=colors[color_idx[plot_idx]],
                 alpha=0.4)
      xlim = ax.get_xlim()
      ylim = ax.get_ylim()
      for i in range(len(unique_classes)):
          ax.scatter(xlim[1]*10, 
                     ylim[1]*10, 
                     c=colors[i], 
                     label=unique_classes[i])
      ax.set_xlim(xlim)
      ax.set_ylim(ylim)
      plt.legend()
      plt.show()
      

      【讨论】:

        最近更新 更多