【问题标题】:Matplotlib scatter(): Polynomial regression line [duplicate]Matplotlib scatter():多项式回归线
【发布时间】:2018-08-27 17:14:30
【问题描述】:

是否可以在 matplotlib 中的 scatter() 上做多项式回归线?

这是我的图表: https://imgur.com/a/Xh1BO

    alg_n = [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4...]
    orig_hc_runtime = [0.01, 0.02, 0.03, 0.04, 0.04, 0.04, 0.05, 0.09...]

    plt.scatter(alg_n, orig_hc_runtime, label="Orig HC", color="b", s=4)
    plt.scatter(alg_n, mod_hc_runtime, label="Mod HC", color="c", s=4)
    ...

    x_values = [x for x in range(5, n_init+2, 2)]
    y_values = [y for y in range(0, 10, 2)]

    plt.xlabel("Number of Queens")
    plt.ylabel("Time (sec)")
    plt.title("Algorithm Performance: Time")
    plt.xticks(x_values)
    plt.yticks(y_values)
    plt.grid(linewidth="1", color="white")
    plt.legend()
    plt.show()

eat 数据集是否可以有回归线?如果是这样,请您解释一下我该怎么做。

【问题讨论】:

  • 你能告诉我们你用来生成这些图的代码吗?
  • 更新示例代码

标签: python matplotlib


【解决方案1】:

我建议您使用 Seaborn 库。它建立在 matplotlib 之上,并具有许多统计绘图例程。查看regplotlmplot 的示例:http://seaborn.pydata.org/tutorial/regression.html#functions-to-draw-linear-regression-models

在您的情况下,您可以执行以下操作:

import pandas as pd
import seaborn as sns
df = pd.DataFrame.from_dict({"Number of Queens": [1, 1, 1, 2, 2, 2, 3,
                                                  3, 3, 4, 4, 4],
                             "Time (sec)": [0.01, 0.02, 0.03, 0.04, 0.04, 0.04,
                                            0.05, 0.09, 0.12, 0.14, 0.15, 0.16]})
sns.lmplot('Number of Queens', 'Time (sec)', df, order=1)

如果您想要不同组的回归线,请添加带有组标签的列并将其添加到lm_plothue 参数。

【讨论】:

  • 太棒了!谢谢
【解决方案2】:

不确定是否可以仅使用 matplotlib 完成,但您始终可以单独计算回归并绘制它。我留下一个使用 scikit-learn 计算回归线的示例代码。

import numpy as np
from sklearn.preprocessing import PolynomialFeatures
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline

x = [1, 2, 3, 4, 5, 8, 10]
y = [1.1, 3.8, 8.5, 16, 24, 65, 99.2]

model = make_pipeline(PolynomialFeatures(2), LinearRegression())
model.fit(np.array(x).reshape(-1, 1), y)
x_reg = np.arange(11)
y_reg = model.predict(x_reg.reshape(-1, 1))

plt.scatter(x, y)
plt.plot(x_reg, y_reg)
plt.show()

输出:

【讨论】:

    猜你喜欢
    • 2021-03-11
    • 2023-01-17
    • 2020-05-08
    • 2016-06-10
    • 2020-09-08
    • 2020-10-26
    • 2021-01-28
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多