【发布时间】:2021-03-29 11:46:46
【问题描述】:
设置:
- 我有一个包含一些 NaN 的数据集。
- 我想拟合 LogisticRegression 并将这些预测输入 HistGradiantBoostingClassifier
- 我希望 HistGradiantBoostingClassifier 使用其自己的内部 NaN 处理
首先,Debug 类可以帮助查看发生了什么
from sklearn.base import BaseEstimator, TransformerMixin
import numpy as np
class Debug(BaseEstimator, TransformerMixin):
def __init__(self, msg='DEBUG'):
self.msg=msg
def transform(self, X):
self.shape = X.shape
print(self.msg)
print(f'Shape: {self.shape}')
print(f'NaN count: {np.count_nonzero(np.isnan(X))}')
return X
def fit(self, X, y=None, **fit_params):
return self
现在是我的管道
from sklearn.experimental import enable_hist_gradient_boosting
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.ensemble import StackingClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
data = load_breast_cancer()
X = data['data']
y = data['target']
X[0, 0] = np.nan # make a NaN
lr_pipe = make_pipeline(
Debug('lr_pipe START'),
SimpleImputer(),
StandardScaler(),
LogisticRegression()
)
pipe = StackingClassifier(
estimators=[('lr_pipe', lr_pipe)],
final_estimator=HistGradientBoostingClassifier(),
passthrough=True,
cv=2,
verbose=10
)
pipe.fit(X, y)
应该发生什么:
- LogisticRegression 适合整个数据集以供以后预测(此处未使用)
- 为了将特征输入 HGB,LogisticRegression 需要
cross_val_predict,我指定了 2 折。我应该看到lr_pipe被称为两次,以便生成折叠预测。
实际会发生什么:
lr_pipe START
Shape: (569, 30)
NaN count: 1
lr_pipe START
Shape: (284, 30)
NaN count: 0
lr_pipe START
Shape: (285, 30)
NaN count: 1
lr_pipe START
Shape: (285, 30)
NaN count: 1
lr_pipe START
Shape: (284, 30)
NaN count: 0
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done 1 out of 1 | elapsed: 0.0s remaining: 0.0s
[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.0s remaining: 0.0s
[Parallel(n_jobs=1)]: Done 2 out of 2 | elapsed: 0.0s finished
为什么lr_pipe 被调用了 5 次?我应该看到它被调用了 3 次。
【问题讨论】:
标签: python machine-learning scikit-learn