【问题标题】:How to scatter plot with 1 million points如何用 100 万个点散点图
【发布时间】:2021-12-14 08:12:14
【问题描述】:

我正在尝试制作一个程序,该程序使用 csv 文件中的给定点绘制图形,该文件包含每行 4 个字符串(点数、x pos、y pos、颜色),但所需时间是高得离谱,所以我正在寻找使其更快的想法。

from matplotlib import pyplot as plt    
from matplotlib import style   
import csv

style.use('ggplot')

s = 0.5
with open('total.csv') as f:
  f_reader = csv.reader(f, delimiter=',')
  for row in f_reader:
    plt.scatter(str(row[1]), str(row[2]), color=str(row[3]), s=s)
plt.savefig("graph.png", dpi=1000)

【问题讨论】:

  • 如果您只需要了解数据中的趋势,请随机对点进行二次采样,并且仅显示 1k。无论哪种方式,读取所有数据并且只调用 scatter 一次(或者如果您对某些点有特定颜色,则每种颜色调用一次)将比为每个点调用 scatter 快得多

标签: python matplotlib graph scatter


【解决方案1】:

第一步是调用scatter 一次而不是每个点,不添加对 numpy 和 pandas 的依赖,它可能看起来像:

from matplotlib import pyplot as plt
from matplotlib import style
import csv

style.use("ggplot")

s = 0.5
x = []
y = []
c = []
with open("total.csv") as f:
    f_reader = csv.reader(f, delimiter=",")
    for row in f_reader:
        x.append(row[1])
        y.append(row[2])
        c.append(row[3])
plt.scatter(x, y, color=c, s=s)
plt.savefig("graph.png", dpi=1000)

然后也许试试pandas.read_csv,它会给你一个pandas数据框,让你可以在没有for循环的情况下访问CSV的列,这可能会更快。

每次尝试变体时,测量它所花费的时间(可能在较小的文件上)以了解哪些有帮助,哪些没有帮助,换句话说,不要盲目地尝试增强性能。

使用 pandas 看起来像:

from matplotlib import pyplot as plt
from matplotlib import style
import pandas as pd

style.use("ggplot")

total = pd.read_csv("total.csv")
plt.scatter(total.x, total.y, color=total.color, s=0.5)
plt.savefig("graph.png", dpi=1000)

如果您想了解更多关于 pandas 性能方面的良好做法,我喜欢 No more sad pandas 讲座,看看吧。

【讨论】: