【问题标题】:What's the full specification for implementing a custom scikit-learn estimator?实现自定义 scikit-learn 估计器的完整规范是什么?
【发布时间】:2018-05-13 04:43:05
【问题描述】:

我正在滚动自己的预测器,并希望像使用任何 scikit 例程(例如 RandomForestRegressor)一样使用它。我有一个包含fitpredict 方法的类,它们似乎工作正常。但是,当我尝试使用一些 scikit 方法(例如交叉验证)时,会出现如下错误:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Python27\lib\site-packages\sklearn\cross_validation.py", line 1152, in cross_val_
score
    for train, test in cv)
  File "C:\Python27\lib\site-packages\sklearn\externals\joblib\parallel.py", line 516, in __
call__
    for function, args, kwargs in iterable:
  File "C:\Python27\lib\site-packages\sklearn\cross_validation.py", line 1152, in <genexpr>
    for train, test in cv)
  File "C:\Python27\lib\site-packages\sklearn\base.py", line 43, in clone
    % (repr(estimator), type(estimator)))
TypeError: Cannot clone object '<__main__.Custom instance at 0x033A6990>' (type <type 'inst
ance'>): it does not seem to be a scikit-learn estimator a it does not implement a 'get_para
ms' methods.

我看到它希望我实现一些方法(大概是 get_params 以及可能是 set_paramsscore),但我不确定制作这些方法的正确规范是什么。有没有关于这个主题的一些信息?谢谢。

【问题讨论】:

标签: python scikit-learn


【解决方案1】:

完整的说明在scikit-learn docs 中提供,API 背后的原理在this paper by yours truly et al. 中列出。简而言之,除了fit,您需要的估计器是get_paramsset_params,它们返回(作为dict)并设置(来自kwargs)估计器的超参数,即学习算法本身的参数(而不是它学习的数据参数)。这些参数应与__init__ 参数匹配。

这两种方法都可以通过继承sklearn.base中的类来获得,但是如果你不希望你的代码依赖于scikit-learn,你可以自己提供它们。

请注意,输入验证应在fit 中完成,而不是在构造函数中,因为否则您仍然可以在set_params 中设置无效参数并让fit 以意想不到的方式失败。

【讨论】:

  • 看来get_params 现在也应该有一个布尔值deep 参数。
  • 如果需要克隆,为什么不只需要对象本身定义的clone 方法?
猜你喜欢
  • 2020-09-25
  • 2018-04-08
  • 2020-03-07
  • 2019-08-18
  • 2013-06-04
  • 2017-03-09
  • 2021-03-04
  • 2019-06-13
  • 2021-01-02
相关资源
最近更新 更多