【发布时间】:2019-08-12 12:09:38
【问题描述】:
我很难理解 GridsearchCV 结合自定义转换的真正工作原理。
我想要实现的目标: 我想实现一个 Transformer/Estimator,它允许根据参数在某些方法之间切换,因为我想在 gridsearch 中包含这些不同的方法。
示例:我有一个名为 Scaler() 的自定义 Transformer,它可以选择 MinMaxScaler 或 StandardScaler。 (只是为了简单)
class Scaling():
def __init__(self, **params):
self.method=None
self.params = {}
print("INITIATING CLASS")
def fit(self, X, y=None):
return self
def transform(self, X):
print("TRANSFORMING", X)
if self.method == "minMax":
self.scaler =
MinMaxScaler(feature_range=self.params["feature_range"])
elif self.method == "std":
self.scaler = StandardScaler()
return self.scaler.fit_transform(X)
def get_params(self, **params):
return {**StandardScaler().get_params(), **MinMaxScaler().get_params(),
**{"method":""} }
def set_params(self, **params):
print("SETTING PARAMETER")
self.method = params["method"]
self.params = params
这是我的示例数据:
data = np.array([1,2,3,4,5,6,7,8,9,10]).reshape(-1,1)
y = [2,3,4,5,6,7,8,9,10,11]
我的管道:
p = Pipeline([('scaler', Scaling()),
('model', LinearRegression())])
我的参数网格和网格搜索
hyperparams = {
'scaler__feature_range' : [(0,1), (-100,10)],
'scaler__method':["minMax"]
}
clf = GridSearchCV(p,hyperparams, cv=2)
clf.fit(data, y)
它确实有效,但我对打印日志感到很困惑:
INITIATING CLASS
INITIATING CLASS
INITIATING CLASS
SETTING PARAMETER
TRANSFORMING [[ 6][ 7][ 8][ 9][10]]
TRANSFORMING [[1][2][3][4][5]]
TRANSFORMING [[ 6][ 7][ 8][ 9][10]]
INITIATING CLASS
SETTING PARAMETER
TRANSFORMING [[1][2][3][4][5]]
TRANSFORMING [[ 6][ 7][ 8][ 9][10]]
TRANSFORMING [[1][2][3][4][5]]
INITIATING CLASS
SETTING PARAMETER
TRANSFORMING [[ 6][ 7][ 8][ 9][10]]
TRANSFORMING [[1][2][3][4][5]]
TRANSFORMING [[ 6][ 7][ 8][ 9][10]]
INITIATING CLASS
SETTING PARAMETER
TRANSFORMING [[1][2][3][4][5]]
TRANSFORMING [[ 6][ 7][ 8][ 9][10]]
TRANSFORMING [[1][2][3][4][5]]
INITIATING CLASS
SETTING PARAMETER
TRANSFORMING [[ 1][ 2 [ 3][ 4][ 5][ 6][ 7][ 8][ 9][10]]
我已设置 cv=2。我希望它是这样的。
- 实例化所有变形金刚
- 根据Gridsearch设置参数
- 通过管道传递 train-fold
- 通过管道传递测试折叠
- 重复 所以我预计对transformer方法有8次调用,因为我们需要一个用于火车,一个用于测试折叠。由于 cv=2,我们这样做了 2 次,并且因为我们在参数网格中为 feature_range 定义了两个不同的值,所以我们必须将其乘以 2,因此是 8。出了什么问题?
但是为什么我的 Scaling 类有这么多的调用呢? 如何解释这种日志顺序? 为什么最后的全序列被转换了?
【问题讨论】:
标签: python scikit-learn pipeline grid-search