【问题标题】:Multiple linear regression scikit-learn and statsmodel多元线性回归 scikit-learn 和 statsmodel
【发布时间】:2016-10-23 12:24:23
【问题描述】:

我正在尝试使用 scikit-learn 对数据集使用多元线性回归,但我无法获得正确的系数。我使用的是休伦湖数据,可以在这里找到:

https://vincentarelbundock.github.io/Rdatasets/datasets.html

在转换它之后,我有以下一组值:

         x1        x2        y
0  0.202165  1.706366  0.840567
1  1.706366  0.840567  0.694768
2  0.840567  0.694768 -0.291031
3  0.694768 -0.291031  0.333170
4 -0.291031  0.333170  0.387371
5  0.333170  0.387371  0.811572
6  0.387371  0.811572  1.415773
7  0.811572  1.415773  1.359974
8  1.415773  1.359974  1.504176
9  1.359974  1.504176  1.768377
...  ...       ...       ...

使用

df = pd.DataFrame(nvalues, columns=("x1", "x2", "y"))
result = sm.ols(formula="y ~ x2 + x1", data=df).fit()

print(result.params)

产量

Intercept   -0.007852
y2           1.002137
y1          -0.283798

这是正确的值,但如果我最终使用 scikit-learn,我会得到:

a = np.array([nvalues["x1"], nvalues["x2"]])
b = np.array(nvalues["y"])

a = a.reshape(len(nvalues["x1"]), 2)
b = b.reshape(len(nvalues["y"]), 1)

clf = linear_model.LinearRegression()
clf.fit(a, b)

print(clf.coef_)

我收到[[-0.18260922 0.08101687]]

为了完整我的代码

from sklearn import linear_model

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import statsmodels.formula.api as sm

def Main():
    location = r"~/Documents/Time Series/LakeHuron.csv"
    ts = pd.read_csv(location, sep=",", parse_dates=[0], header=0)

    #### initializes the data ####
    ts.drop("Unnamed: 0", axis=1, inplace=True)

    x = ts["time"].values
    y = ts["LakeHuron"].values

    x = x.reshape(len(ts), 1)
    y = y.reshape(len(ts), 1)

    regr = linear_model.LinearRegression()
    regr.fit(x, y)

    diff = []
    for i in range(0, len(ts)):
        diff.append(float(ts["LakeHuron"][i]-regr.predict(x)[i]))

    ts[3] = diff

    nvalues = {"x1": [], "x2": [], "y": []}

    for i in range(0, len(ts)-2):
        nvalues["x1"].append(float(ts[3][i]))
        nvalues["x2"].append(float(ts[3][i+1]))
        nvalues["y"].append(float(ts[3][i+2]))

    df = pd.DataFrame(nvalues, columns=("x1", "x2", "y"))
    result = sm.ols(formula="y ~ x2 + x1", data=df).fit()

    print(result.params)

    #### using scikit-learn ####
    a = np.array([nvalues["x1"], nvalues["x2"]])
    b = np.array(nvalues["y"])

    a = a.reshape(len(nvalues["x1"]), 2)
    b = b.reshape(len(nvalues["y"]), 1)

    clf = linear_model.LinearRegression()
    clf.fit(a, b)

    print(clf.coef_)

if __name__ == "__main__":
    Main()

【问题讨论】:

  • 实际上,我会说,[[-0.18260922 0.08101687]](截距 [-0.02583547]print(clf.intercept_) 给出)是正确的值。 R 给出相同的值。而且,事实上,statsmodels.api.OLS 也给出了相同的值(要尝试,请执行以下操作(将每个 ; 替换为换行符):import statsmodels.api as sm2;import statsmodels.tools.tools as smtools;a2 = smtools.add_constant(a);result2 = sm2.OLS(b,a2).fit();print(result2.params))。只是statsmodels.formula.api.ols 在上面显示了这些不同的值。
  • 我试图复制结果的书与 statsmodels.api 显示的结果相似。我想使用 scikit learn,因为我在 AR(1) 示例中取得了非常好的成功,并希望将其扩展到 AR(p) 和 MA(q) 模型,并最终扩展到 ARIMA(p,d,q) 模型
  • 你是对的。问题出在由 ndarray 的不正确 reshape() 导致的数据集上。请参阅下面的答案。

标签: python scikit-learn statsmodels


【解决方案1】:

根据@Orange 的建议,我已将代码更改为我认为更有效的代码:

#### using scikit-learn ####
a = []
for i in range(0, len(nvalues["x1"])):
    a.append([nvalues["x1"][i], nvalues["x2"][i]])

a = np.array(a)
b = np.array(nvalues["y"])

a = a.reshape(len(a), 2)
b = b.reshape(len(nvalues["y"]), 1)

clf = linear_model.LinearRegression()
clf.fit(a, b)

print(clf.coef_) 

这类似于 scikit-learn 网站上的简单回归示例

【讨论】:

    【解决方案2】:

    问题是线

    a = np.array([nvalues["x1"], nvalues["x2"]])
    

    因为它不会按照您想要的方式对数据进行排序。相反,它将生成一个数据集

    x1_new    x2_new
    -----------------
     x1[0]     x1[1]
     x1[2]     x1[3]
    [...]
     x1[94]    x1[95]
     x2[0]     x2[1]
    [...]
    

    试试吧

    ax1 = np.array(nvalues["x1"])
    ax2 = np.array(nvalues["x2"])
    ax1 = ax1.reshape(len(nvalues["x1"]), 1)
    ax2 = ax2.reshape(len(nvalues["x2"]), 1)
    a = np.hstack([ax1,ax2])
    

    可能有一种更清洁的方法可以做到这一点,但这种方式是有效的。回归现在也给出了所有正确的结果。

    编辑: 更简洁的方法是使用transpose():

    a = a.transpose()
    

    【讨论】:

      猜你喜欢
      • 2022-01-01
      • 2023-03-25
      • 2019-05-14
      • 2017-12-25
      • 2018-07-31
      • 2016-05-10
      • 2015-02-14
      • 2021-04-01
      • 2017-03-26
      相关资源
      最近更新 更多