【问题标题】:sklean Standard Scaler Ridge Pipelinesklearn Standardscaler Ridge Pipeline
【发布时间】:2022-01-12 01:46:19
【问题描述】:

我正在尝试标准化特征,然后运行岭回归。

正如所提供的,这两个答案是不同的。

当我设置 ridge=0 时,答案是一样的。当我删除 StandardScaler 和 Dn 时,答案也是一样的。

我不知道如何协调这两个版本(原始版本和使用 sklearn)。

感谢您的帮助

from sklearn.linear_model import Ridge
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
import numpy as np
np.random.seed(0)

x = np.random.randn(100, 3)
y = np.random.randn(100, 2)

xx = x.T @ x
xy = x.T @ y
Dn = np.diag(1 / np.sqrt(np.diag(xx)))

ridge = 1

xx = Dn @ xx @ Dn
xy = Dn @ xy
beta_raw = Dn @ np.linalg.solve(xx + np.eye(len(xx)) * ridge, xy)
f_raw = x @ beta_raw

model = Pipeline([("scaler", StandardScaler(with_mean=False)), ("regression", Ridge(ridge, fit_intercept=False))])
trained_model = model.fit(x, y)
f_ml = trained_model.predict(x)

print(f_ml[:3] / f_raw[:3])

【问题讨论】:

    标签: python numpy scikit-learn


    【解决方案1】:

    您正在按不同的值进行缩放,检查:

    np.diag(Dn)
    array([0.09699826, 0.10123938, 0.1016412 ])
    
    model.steps[0][1].scale_
    array([1.02603414, 0.98202661, 0.97598415])
    

    您的标准差是协方差矩阵对角线的平方。即使您没有使矩阵居中,您仍然需要减去均值以获得协方差。见this post for more information

    所以如果我们做对了:

    x_m = x.mean(axis=0)
    x_cov = np.dot((x - x_m).T, x - x_m) / (x.shape[0])
    Dn = np.diag(1 / np.sqrt(np.diag(x_cov)))
    
    xx = x.T @ x
    xy = x.T @ y
    
    ridge = 1
    
    xx = Dn @ xx @ Dn
    xy = Dn @ xy
    beta_raw = Dn @ np.linalg.solve(xx + np.eye(len(xx)) * ridge, xy)
    f_raw = x @ beta_raw
    
    model = Pipeline([("scaler", StandardScaler(with_mean=False)), ("regression", Ridge(ridge, fit_intercept=False))])
    trained_model = model.fit(x, y)
    f_ml = trained_model.predict(x)
    
    print(f_ml[:3] / f_raw[:3])
    
    [[1. 1.]
     [1. 1.]
     [1. 1.]]
    

    【讨论】:

    • 谢谢!这是有道理的。假设我想让 sklearn 管道类似于原始方式,而不是相反。有没有简单的方法可以做到这一点?
    猜你喜欢
    • 2018-12-29
    • 2019-02-17
    • 2017-10-04
    • 2018-03-15
    • 2020-07-16
    • 2016-08-25
    • 2021-05-11
    • 2021-06-17
    • 2014-05-11
    相关资源
    最近更新 更多