【问题标题】:Data not persistent in scikit-learn transformers数据在 scikit-learn 转换器中不持久
【发布时间】:2017-01-25 08:45:44
【问题描述】:

我想将其他数据传递给 scikit-learn 中的转换器:

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.ensemble import RandomForestClassifier

from sklearn.pipeline import Pipeline
import numpy as np
from sklearn.model_selection import GridSearchCV

class myTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, my_np_array):
        self.data = my_np_array
        print self.data

    def transform(self, X):
        return X

    def fit(self, X, y=None):
        return self

data = np.random.rand(20,20)
data2 = np.random.rand(6,6)
y = np.array([1, 2, 3, 1, 2, 3, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 3, 3, 3, 3])

pipe = Pipeline(steps=[('myt', myTransformer(data2)), ('randforest', RandomForestClassifier())])
params = {"randforest__n_estimators": [100, 1000]}
estimators = GridSearchCV(pipe, param_grid=params, verbose=True)
estimators.fit(data, y)

但是,当在 scikit-learn 管道中使用时,它似乎消失了

我从 init 方法中的 print 得到None。我该如何解决?

【问题讨论】:

  • 你确定传递时mydata不是None吗?
  • 是的,不是没有
  • 我想您应该将.fit 添加到估计器中以使错误出现。
  • 好的,谢谢。刚刚添加

标签: python machine-learning scikit-learn


【解决方案1】:

发生这种情况是因为 sklearn 以非常具体的方式处理估算器。一般来说,它会为网格搜索创建一个新的类实例,并将参数传递给构造函数。发生这种情况是因为 sklearn 有自己的 clone 操作 (defined in base.py) 它接受您的估算器类,获取参数(由get_params 返回)并将其传递给您的类的构造函数

klass = estimator.__class__
new_object_params = estimator.get_params(deep=False)
for name, param in six.iteritems(new_object_params):
    new_object_params[name] = clone(param, safe=False)
new_object = klass(**new_object_params) 

为了支持您的对象必须覆盖get_params(deep=False)方法,该方法应返回字典,该字典将传递给构造函数

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
class myTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, my_np_array):
        self.data = my_np_array
        print self.data

    def transform(self, X):
        return X

    def fit(self, X, y=None):
        return self

    def get_params(self, deep=False):
        return {'my_np_array': self.data}

将按预期工作。

【讨论】:

  • 不适用 int 和 float,是吗?我是否必须处理实例本地的所有变量(self.variable)?例如,我需要它来获得最佳估算器
  • “不应用 int 和 float”?我要说的是您需要一个 get_params 方法,它将提供克隆转换器实例所需的一切。你可以将它们“存储”到任何你想要的地方,get_params 需要返回它,但是你可以将它们作为静态属性、全局变量,任何你想要的。
  • 很抱歉再次打扰您,但我不确定我是否理解。我不太明白为什么参数不可见,因为网格搜索的目的是尝试不同的设置并找到最佳设置。那么估算器在官方图书馆中是如何工作的呢?谢谢,
  • 这与可见性无关。这就是 scikit-learn 开发人员决定克隆对象的方式(它有优点也有缺点,但这是他们的设计决定)。这里没有什么要补充的。他们(估计器)都实现了这个接口——他们有 get_params,被调用——返回复制对象所需的一切。还有 set_param 有对称意义等。
猜你喜欢
  • 2015-10-12
  • 2015-03-07
  • 2020-02-04
  • 2017-12-17
  • 2021-10-25
  • 2018-03-15
  • 2016-11-01
  • 2014-04-29
  • 2018-06-01
相关资源
最近更新 更多