【问题标题】:How to to graph multiple lines using sns.scatterplot如何使用 sns.scatterplot 绘制多条线
【发布时间】:2022-11-12 15:25:00
【问题描述】:

我写了一个这样的程序:

# Author: Evan Gertis
# Date  : 11/09
# program: Linear Regression
# Resource: https://seaborn.pydata.org/generated/seaborn.scatterplot.html       
import seaborn as sns
import pandas as pd
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Step 1: load the data
grades = pd.read_csv("grades.csv") 
logging.info(grades.head())

# Step 2: plot the data
plot = sns.scatterplot(data=grades, x="Hours", y="GPA")
fig = plot.get_figure()
fig.savefig("out.png")

使用数据集

Hours,GPA,Hours,GPA,Hours,GPA
11,2.84,9,2.85,25,1.85
5,3.20,5,3.35,6,3.14
22,2.18,14,2.60,9,2.96
23,2.12,18,2.35,20,2.30
20,2.55,6,3.14,14,2.66
20,2.24,9,3.05,19,2.36
10,2.90,24,2.06,21,2.24
19,2.36,25,2.00,7,3.08
15,2.60,12,2.78,11,2.84
18,2.42,6,2.90,20,2.45

我想绘制出所有的关系,此时我只得到一个图:

预期的: 绘制的所有关系

实际的:

对此的任何帮助将不胜感激。谢谢!

我写了一个基本程序,我期待所有的关系都被绘制出来。

【问题讨论】:

  • 只是一个简单的问题:是否有理由不将数据放在 2 列中?像这样重构它可以解决问题,但它可能不是您想要的?

标签: python seaborn


【解决方案1】:

问题的根源是文件中的列名相同,因此当熊猫读取列时,会向加载的数据框添加编号

import seaborn as sns
import pandas as pd
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

grades = pd.read_csv("grades.csv") 
print(grades.columns)
>>> Index(['Hours', 'GPA', 'Hours.1', 'GPA.1', 'Hours.2', 'GPA.2'], dtype='object')

因此,当您绘制散点图时,您需要给出熊猫给出的列名

# in case you want all scatter plots in the same figure
plot = sns.scatterplot(data=grades, x="Hours", y="GPA", label='GPA')
sns.scatterplot(data=grades, x='Hours.1', y='GPA.1', ax=plot, label="GPA.1")
sns.scatterplot(data=grades, x='Hours.2', y='GPA.2', ax=plot,  label='GPA.2')
fig = plot.get_figure()
fig.savefig("out.png")

【讨论】:

  • 你知道如何在图例上标记它们吗?谢谢你。
  • 更多信息可以在相关问题here 中找到,关于实现此行为的参数:mangle_dupe_cols
  • @EvanGertis我已经用参数标签更新了答案,这是获得传奇的一种选择
【解决方案2】:
  • 有比为每组列手动创建图更好的选择
  • 由于文件中的列有多余的名称,pandas 会自动重命名它们。

导入和数据框

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# read the data from the file
df = pd.read_csv('d:/data/gpa.csv')

# display(df)
   Hours   GPA  Hours.1  GPA.1  Hours.2  GPA.2
0     11  2.84        9   2.85       25   1.85
1      5  3.20        5   3.35        6   3.14
2     22  2.18       14   2.60        9   2.96
3     23  2.12       18   2.35       20   2.30
4     20  2.55        6   3.14       14   2.66
5     20  2.24        9   3.05       19   2.36
6     10  2.90       24   2.06       21   2.24
7     19  2.36       25   2.00        7   3.08
8     15  2.60       12   2.78       11   2.84
9     18  2.42        6   2.90       20   2.45

选项 1:分块列名

  • 此选项可用于在循环中绘制数据,而无需手动创建每个图
  • 使用来自How to iterate over a list in chunksanswer 将创建列名组列表:
    • [Index(['Hours', 'GPA'], dtype='object'), Index(['Hours.1', 'GPA.1'], dtype='object'), Index(['Hours.2', 'GPA.2'], dtype='object')]
# create groups of column names to be plotted together
def chunker(seq, size):
    return [seq[pos:pos + size] for pos in range(0, len(seq), size)]


# function call
col_list = chunker(df.columns, 2)

# iterate through each group of column names to plot
for x, y in chunker(df.columns, 2):
    sns.scatterplot(data=df, x=x, y=y, label=y)

选项 2:修复数据

# filter each group of columns
h = df.filter(like='Hours').values.ravel()
g = df.filter(like='GPA').values.ravel()

# get the gpa column names
gpa_cols = df.columns[1::2]

# use numpy to create a list of labels with the appropriate length
labels = np.repeat(gpa_cols, len(df))

# create a new dataframe
dfl = pd.DataFrame({'hours': h, 'gpa': g, 'label': labels})

# save dfl if desired
dfl.to_csv('gpa_long.csv', index=False)

# plot
sns.scatterplot(data=df, x='hours', y='gpa', hue='label')

绘图结果

【讨论】:

    猜你喜欢
    • 2021-04-26
    • 2021-09-14
    • 1970-01-01
    • 2017-06-13
    • 2021-06-06
    • 2019-12-01
    • 2021-04-17
    • 2018-12-17
    • 1970-01-01
    相关资源
    最近更新 更多