编辑:
生成一些随机数据:
from scipy.cluster.vq import kmeans2
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
n_clusters = 10
df = pd.DataFrame({'x':np.random.randn(1000), 'y':np.random.randn(1000)})
_, df['cluster'] = kmeans2(df, n_clusters)
更新
- 使用
seaborn.relplot 和kind='scatter' 或使用seaborn.scatterplot
# figure level plot
sns.relplot(data=df, x='x', y='y', hue='cluster', palette='tab10', kind='scatter')
# axes level plot
fig, axes = plt.subplots(figsize=(6, 6))
sns.scatterplot(data=df, x='x', y='y', hue='cluster', palette='tab10', ax=axes)
axes.legend(loc='center left', bbox_to_anchor=(1, 0.5))
原答案
绘图(matplotlib v3.3.4):
fig, ax = plt.subplots(figsize=(8, 6))
cmap = plt.cm.get_cmap('jet')
for i, cluster in df.groupby('cluster'):
_ = ax.scatter(cluster['x'], cluster['y'], color=cmap(i/n_clusters), label=i, ec='k')
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
结果:
解释:
不要过多介绍 matplotlib 内部的细节,一次绘制一个集群可以解决问题。
更具体地说,ax.scatter() 返回一个 PathCollection 对象,我们在此处明确将其丢弃,但 似乎 在内部传递给某种图例处理程序。一次绘制只生成一个PathCollection/label 对,而一次绘制一个集群会生成n_clustersPathCollection/label 对。您可以通过调用 ax.get_legend_handles_labels() 来查看这些对象,它会返回如下内容:
([<matplotlib.collections.PathCollection at 0x7f60c2ff2ac8>,
<matplotlib.collections.PathCollection at 0x7f60c2ff9d68>,
<matplotlib.collections.PathCollection at 0x7f60c2ff9390>,
<matplotlib.collections.PathCollection at 0x7f60c2f802e8>,
<matplotlib.collections.PathCollection at 0x7f60c2f809b0>,
<matplotlib.collections.PathCollection at 0x7f60c2ff9908>,
<matplotlib.collections.PathCollection at 0x7f60c2f85668>,
<matplotlib.collections.PathCollection at 0x7f60c2f8cc88>,
<matplotlib.collections.PathCollection at 0x7f60c2f8c748>,
<matplotlib.collections.PathCollection at 0x7f60c2f92d30>],
['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'])
所以实际上ax.legend() 等价于ax.legend(*ax.get_legend_handles_labels())。
注意事项:
-
如果使用 Python 2,请确保 i/n_clusters 是 float
-
省略 fig, ax = plt.subplots() 并改用 plt.<method>
ax.<method> 工作正常,但我总是更喜欢明确
指定我正在使用的 Axes 对象,而不是隐式使用
“当前坐标区” (plt.gca())。
旧的简单解决方案
如果您对颜色条(而不是离散值标签)没问题,您可以使用 Pandas 内置的 Matplotlib 功能:
df.plot.scatter('x', 'y', c='cluster', cmap='jet')