【问题标题】:How to get the numerical fitting results when plotting a regression in seaborn?在seaborn中绘制回归时如何获得数值拟合结果?
【发布时间】:2014-05-16 03:27:32
【问题描述】:

如果我使用 Python 中的 seaborn 库来绘制线性回归的结果,有没有办法找出回归的数值结果?例如,我可能想知道拟合系数或拟合的 R2

我可以使用底层 statsmodels 接口重新运行相同的拟合,但这似乎是不必要的重复工作,无论如何我希望能够比较结果系数以确保数值结果相同正如我在剧情中看到的那样。

【问题讨论】:

  • 有谁知道您是否可以在 seaborn 的条形图上绘制每个条形的实际值,而不是通过查看 Y 轴并尝试匹配它来猜测值?在我看到的 Seaborn 可视化库的所有示例中,没有人将实际值放在各个条上以显示确切值,它们都是完全可视的。
  • 请参阅here 可能重复的问题中的解决方案。就像在那个答案的 cmets 中一样,一个人可以很容易地get the equation line with two points 然后plot it

标签: python seaborn


【解决方案1】:

没有办法做到这一点。

在我看来,要求可视化库为您提供统计建模结果是倒退的。 statsmodels 是一个建模库,可让您拟合模型,然后绘制与您拟合的模型完全对应的图。如果您想要精确的对应关系,那么这种操作顺序对我来说更有意义。

您可能会说“但statsmodels 中的情节没有seaborn 那么多的美学选择”。但我认为这是有道理的——statsmodels 是一个建模库,有时在建模服务中使用可视化。 seaborn 是一个可视化库,有时在可视化服务中使用建模。专精是好的,什么都做不好。

幸运的是,seabornstatsmodels 都使用 tidy data。这意味着您真的只需要很少的重复工作即可通过适当的工具获得绘图和模型。

【讨论】:

  • @user333700,同意。由于这个限制,我目前没有使用 seaborn,尽管我可能会看一下。如果现在没有办法,我可能会建议一个功能,其中来自 statsmodels 的 fit 对象可以用作适当 seaborn 绘图函数的输入。
  • @mwaskom,我刚刚收到通知,这个问题已获得 2500 次浏览。只是一个数据点,以防您想知道有多少人在寻找此功能。
  • @user333700 为什么要运行两次回归? Seaborn 已经在为您驾驶汽车,只是忘记告诉您它停在哪里。它只是向您发送快照并祝您找到它好运
  • 仍然相关。我相信 seaborn 的回归,但由于我无法检查使用的参数,所以没什么意义......很高兴知道最好自己做。少用一个库....
  • 即使对于可视化包来说,这似乎也是一项基本要求。在大多数情况下,在不报告 p 值、r^2 值和系数的情况下呈现数字是不可接受的。我不会认为这是一个专门的功能。正如其他人在 cmets 中提到的那样,它确实使 seaborn 回归无法用于任何合法目的,例如研究文章。
【解决方案2】:

Seaborn 的创建者has unfortunately stated 他不会添加这样的功能。下面是一些选项。 (最后一部分包含我最初的建议,这是一个使用 seaborn 的私有实现细节的 hack,并不是特别灵活。)

regplot 的简单替代版本

以下函数在散点图上覆盖一条拟合线并返回来自statsmodels 的结果。这支持sns.regplot 的最简单也可能是最常见的用法,但没有实现任何更高级的功能。

import statsmodels.api as sm


def simple_regplot(
    x, y, n_std=2, n_pts=100, ax=None, scatter_kws=None, line_kws=None, ci_kws=None
):
    """ Draw a regression line with error interval. """
    ax = plt.gca() if ax is None else ax

    # calculate best-fit line and interval
    x_fit = sm.add_constant(x)
    fit_results = sm.OLS(y, x_fit).fit()

    eval_x = sm.add_constant(np.linspace(np.min(x), np.max(x), n_pts))
    pred = fit_results.get_prediction(eval_x)

    # draw the fit line and error interval
    ci_kws = {} if ci_kws is None else ci_kws
    ax.fill_between(
        eval_x[:, 1],
        pred.predicted_mean - n_std * pred.se_mean,
        pred.predicted_mean + n_std * pred.se_mean,
        alpha=0.5,
        **ci_kws,
    )
    line_kws = {} if line_kws is None else line_kws
    h = ax.plot(eval_x[:, 1], pred.predicted_mean, **line_kws)

    # draw the scatterplot
    scatter_kws = {} if scatter_kws is None else scatter_kws
    ax.scatter(x, y, c=h[0].get_color(), **scatter_kws)

    return fit_results

statsmodels 的结果包含大量信息,例如

>>> print(fit_results.summary())

                            OLS Regression Results                            
==============================================================================
Dep. Variable:                      y   R-squared:                       0.477
Model:                            OLS   Adj. R-squared:                  0.471
Method:                 Least Squares   F-statistic:                     89.23
Date:                Fri, 08 Jan 2021   Prob (F-statistic):           1.93e-15
Time:                        17:56:00   Log-Likelihood:                -137.94
No. Observations:                 100   AIC:                             279.9
Df Residuals:                      98   BIC:                             285.1
Df Model:                           1                                         
Covariance Type:            nonrobust                                         
==============================================================================
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
const         -0.1417      0.193     -0.735      0.464      -0.524       0.241
x1             3.1456      0.333      9.446      0.000       2.485       3.806
==============================================================================
Omnibus:                        2.200   Durbin-Watson:                   1.777
Prob(Omnibus):                  0.333   Jarque-Bera (JB):                1.518
Skew:                          -0.002   Prob(JB):                        0.468
Kurtosis:                       2.396   Cond. No.                         4.35
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.

sns.regplot 的直接替换(几乎)

上面的方法比我下面的原始答案的优点是它很容易扩展到更复杂的配合。

无耻插件:这是我编写的这样一个扩展的regplot 函数,它实现了sns.regplot 的大部分功能:https://github.com/ttesileanu/pydove

虽然还缺少一些功能,但我写的功能

  • 通过将绘图与统计建模分开来实现灵活性(您还可以轻松访问拟合结果)。
  • 对于大型数据集来说要快得多,因为它允许statsmodels 计算置信区间而不是使用自举。
  • 允许稍微多样化的拟合(例如,log(x) 中的多项式)。
  • 允许更细粒度的绘图选项。

旧答案

Seaborn 的创建者 has unfortunately stated 他不会添加这样的功能,所以这里有一个解决方法。

def regplot(
    *args,
    line_kws=None,
    marker=None,
    scatter_kws=None,
    **kwargs
):
    # this is the class that `sns.regplot` uses
    plotter = sns.regression._RegressionPlotter(*args, **kwargs)

    # this is essentially the code from `sns.regplot`
    ax = kwargs.get("ax", None)
    if ax is None:
        ax = plt.gca()

    scatter_kws = {} if scatter_kws is None else copy.copy(scatter_kws)
    scatter_kws["marker"] = marker
    line_kws = {} if line_kws is None else copy.copy(line_kws)

    plotter.plot(ax, scatter_kws, line_kws)

    # unfortunately the regression results aren't stored, so we rerun
    grid, yhat, err_bands = plotter.fit_regression(plt.gca())

    # also unfortunately, this doesn't return the parameters, so we infer them
    slope = (yhat[-1] - yhat[0]) / (grid[-1] - grid[0])
    intercept = yhat[0] - slope * grid[0]
    return slope, intercept

请注意,这仅适用于线性回归,因为它只是从回归结果中推断出斜率和截距。好消息是它使用seaborn 自己的回归类,因此可以保证结果与显示的一致。当然,缺点是我们在 seaborn 中使用了一个私有实现细节,它可能随时中断。

【讨论】:

  • 自从这个答案可以追溯到 1 月 15 日以来,可能是一个长镜头,但是我尝试使用上面的代码,但我收到以下错误:local variable 'scatter_kws' referenced before assignment - 我该如何解决?
  • 原来我在def 中遗漏了一些关键字参数。现在应该可以工作了,感谢您指出这一点,@Marioanzas!
  • 谢谢,这是您在这里提供的一个非常好的功能!一个小的改进使 alpha 值也可以改变:if 'alpha' in ci_kws: alpha = ci_kws['alpha'] del ci_kws['alpha'] else: alpha= 0.5
  • @Exi 当然,我只是想在答案本身中包含一个简短的概念证明。我在 github.com/ttesileanu/pygrutils 的 repo 中的函数有更多的功能,以及对 seaborn 更好的兼容性。
【解决方案3】:

查看当前可用的文档,我能够确定现在是否可以满足此功能的最接近的方法是使用 scipy.stats.pearsonr 模块。

r2 = stats.pearsonr("pct", "rdiff", df)

在尝试使其直接在 Pandas 数据帧中工作时,由于违反了基本的 scipy 输入要求而引发了一个错误:

TypeError: pearsonr() takes exactly 2 arguments (3 given)

我设法找到了另一个显然解决了问题的 Pandas Seaborn 用户 它:https://github.com/scipy/scipy/blob/v0.14.0/scipy/stats/stats.py#L2392

sns.regplot("rdiff", "pct", df, corr_func=stats.pearsonr);

但是,不幸的是,我没有设法让它工作,因为似乎作者创建了自己的自定义“corr_func”,或者有一个未记录的 Seaborn 争论传递方法,可以使用更手动的方法:

# x and y should have same length.
    x = np.asarray(x)
    y = np.asarray(y)
    n = len(x)
    mx = x.mean()
    my = y.mean()
    xm, ym = x-mx, y-my
    r_num = np.add.reduce(xm * ym)
    r_den = np.sqrt(ss(xm) * ss(ym))
    r = r_num / r_den

# Presumably, if abs(r) > 1, then it is only some small artifact of floating
# point arithmetic.
r = max(min(r, 1.0), -1.0)
df = n-2
if abs(r) == 1.0:
    prob = 0.0
else:
    t_squared = r*r * (df / ((1.0 - r) * (1.0 + r)))
    prob = betai(0.5*df, 0.5, df / (df + t_squared))
return r, prob

希望这有助于将这个原始请求推进到一个临时解决方案,因为非常需要实用程序将回归适应度统计数据添加到 Seaborn 包中,以替代人们可以从 MS-Excel 或库存 Matplotlib 线图轻松获得的东西。

【讨论】:

  • 谢谢,有顺序依赖吗?例如,此调用绘制 scatter + linreg 线: sns.lmplot("total_bill", "tip", tips);这个添加了双变量分布+ pearsonsr:sns.jointplot("total_bill", "tip", tips);但没有 linreg 线。是否可以手动将 linreg 添加到此?:sns.lmplot("total_bill", "tip", tips, scatter_kws={"marker": ".", "color": "slategray"}, line_kws={ "linewidth": 1, "color": "seagreen"});
  • 为什么开发者不想包含这些基本信息?我不断看到类似“这很简单,只需使用其他 10 行代码”之类的建议。但这感觉不是很pythonic(尤其是复制已经完成的拟合)。为什么我要使用 Seaborn 而不是只使用 scipy 和 matplotlib 进行拟合,因为我基本上可以保证有足够的时间使用方程?
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2018-04-10
  • 1970-01-01
  • 2019-12-06
  • 1970-01-01
相关资源
最近更新 更多