【问题标题】:Scikit-Learn Linear Regression on Square Matrix Seems Incorrect方阵上的 Scikit-Learn 线性回归似乎不正确
【发布时间】:2016-08-22 09:38:17
【问题描述】:

我将k 响应变量y 线性回归到k x n 预测变量X,其中k >= n。使用 Scikit-Learn,回归似乎是正确的,除非 n = k;即,当预测变量矩阵为正方形时。考虑以下 MWE,其中我随机生成矩阵 X 和系数 b 来构造 y,然后使用 scikit-learn 执行回归以检查系数是否与真实系数相同:

import numpy as np
from sklearn import linear_model


n = 5  # number of predictor variables                                  
k = 5  # number of response variables, >= n *** Set = n for issue ***   
mu_b = 2.0  # mean of each component of b, used to create coeffs        

print "n = ", n

# generate true coefficients ~ N(2,0.09)                                
b = np.random.normal(2.0, 0.3, n)

print "b = ", b

# generate true predictors ~ N(0,25)                                    
X = np.random.normal(0.0, 5.0, (k,n))

# generate true target variables                                        
y = X.dot(b)


# create linear regression object                                       
regr = linear_model.LinearRegression()

# train model                                                           
regr.fit(X,y)

# print coeffs                                                          
print "estimated b = ", regr.coef_

# print difference                                                      
print "difference = ", np.linalg.norm(b - regr.coef_)

如果 k > n 则模机器精度没有差异,但是当 k = n 时可能会有很大差异。有人遇到过这种情况么?这是一个已知的错误吗?

【问题讨论】:

  • 我认为您的概念混淆了,在您的代码中,k 是预测变量的数量,n 是样本大小;并且您需要使 n 严格大于 k 才能使回归有意义。想一想,想象你k=1,所以你只有一个要预测的变量,即你要计算一条线的梯度。你至少需要两个点来计算这个,所以你必须有n>2
  • @maxymoo 不,它的代码是正确的。当我解决系统使用 np.linalg 中的 lstsq 作为方阵,以及 k >= n 使用 scikit-learn 时,它可以工作。我相信 scikit-learn 中存在错误。

标签: python numpy scikit-learn


【解决方案1】:

默认情况下,LinearRegression 类的属性fit_intercept 设置为True。这似乎有两个影响。首先,在使用linalg.lstsq 拟合模型之前,矩阵Xy 通过减去_center_data 方法中的平均值来居中。二、模型拟合好后,_set_intercept设置:

regr.intercept_ = y_mean - np.dot(X_mean, regr.coef_.T)

从文档中不清楚这个截距项是如何得出的。

在您的情况下,您可以检查在k > n 生成的intercept_ 项的顺序为1e-14 的情况下,但对于k = nintercept_ 非零,解释了为什么系数向量在k = n 的情况下不匹配。您可以通过在模型中设置 fit_intercept=False 来解决所有这些问题。

警告:当然,更有意义的答案可能会解释截距项是如何得出的,并提供有关 k > n 截距项为何为零的见解。

【讨论】:

    猜你喜欢
    • 2016-05-10
    • 2019-05-14
    • 2017-12-25
    • 2021-04-02
    • 2018-07-31
    • 2021-04-01
    • 2016-10-23
    • 2017-03-26
    • 2016-07-24
    相关资源
    最近更新 更多