【Question title】: Sklearn fit method error when building a composed estimator
【Published】: 2021-01-18 22:03:26
【Question description】:

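I'm trying to build a composed estimator in sklearn. I have since found that sklearn.compose.TransformedTargetRegressor does exactly what I want, but I still can't replicate it myself, and I'm curious why.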
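For comparison, here is a minimal sketch of how the built-in sklearn.compose.TransformedTargetRegressor can be used for this kind of target transform. The log/exp transform and the synthetic data are assumptions chosen for illustration; they are not taken from the question:

```python
import numpy as np
from sklearn.compose import TransformedTargetRegressor
from sklearn.linear_model import LinearRegression

# Synthetic data (assumption): y = exp(linear function of X), so log(y) is exactly linear.
rng = np.random.RandomState(0)
X = rng.uniform(0, 1, size=(50, 2))
y = np.exp(1.0 + 2.0 * X[:, 0] - 0.5 * X[:, 1])

base = LinearRegression()
ttr = TransformedTargetRegressor(regressor=base, func=np.log, inverse_func=np.exp)

# fit() trains a clone of `base` on log(y) and stores it as ttr.regressor_
ttr.fit(X, y)

# predict() runs the fitted clone and maps its output back with np.exp
y_pred = ttr.predict(X)
print(hasattr(base, "coef_"))  # False: the object passed in is never fitted
```

The design point to notice is that the regressor passed to the constructor stays unfitted; all fitted state lives in the `regressor_` attribute.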

The error I get:

AssertionError: Estimator TransformedSkModel should not change or mutate  the parameter model from LinearRegression() to LinearRegression() during fit.

My code:

import numpy as np
from sklearn.base import BaseEstimator

class TransformedSkModel(BaseEstimator):
    
    def __init__(self, model, transform_function, reverse_transform_function):
        
        # TODO: ideally, verify, based on specified ranges, that the transform
        # function and reverse transform function are compatible
        
        self.model = model
        self.transform_function = transform_function
        self.reverse_transform_function = reverse_transform_function
    
    def fit(self, X, y):
        
        # Trying to reproduce as sklearn compatible estimator, we must use the 
        # conventions:
            
        # The constructor in an estimator should only set attributes to the 
        # values the user passes as arguments. All computation should occur in 
        # fit, and if fit needs to store the result of a computation, it should 
        # do so in an attribute with a trailing underscore (_). This convention 
        # is what makes clone and meta-estimators such as GridSearchCV work.
        
        self.vectorized_transform_function_ =\
            np.vectorize(self.transform_function)
        self.vectorized_reverse_transform_function_ =\
            np.vectorize(self.reverse_transform_function)
            
        y_transformed = self.vectorized_transform_function_(y)
        self.model.fit(X, y_transformed)
        
        return self
        
    def predict(self, X):
        
        y_transformed = self.model.predict(X)
        y = self.vectorized_reverse_transform_function_(y_transformed)
        return y
    
    # def get_params(self, ):
        
    #     return self.model.get_params()
    
if __name__ == "__main__":
    
    from sklearn.utils.estimator_checks import check_estimator
    from sklearn.linear_model import LinearRegression
    
    lm = LinearRegression()
    id_func = lambda x:x
    test = TransformedSkModel(lm, id_func, id_func)
    check_estimator(test)

Edit: I'm using sklearn 0.24.0 and Python 3.6.8.

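【Discussion】:

  • Which version of sklearn are you using?
  • Hey @ctlr; I'm using sklearn 0.24.0 and Python 3.6.8.
  • This seems to be an issue here. If I find a solution, I'll post an answer. Basically, the parameters of your fitted model are being compared with those of the original model, and the fitted model object has learned attributes, which produce a different hash and raise the error.
  • Yes, that's my understanding too. I understand that parameters are not meant to be modified by the fit method, but my sub-estimator has to be fitted. Should it be flagged differently? In the meantime, looking at the code of sklearn.compose.TransformedTargetRegressor should help, since it does what I want without any problem.
  • Here the original estimator is cloned, and the cloned estimator is the one that gets fitted and used for prediction. After that change, I get some other errors caused by id_func.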
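A small sketch of the mechanism described in these comments, using toy data (an assumption). The check compares a joblib.hash of each constructor parameter before and after fit. Fitting the passed-in LinearRegression adds fitted attributes such as coef_, which changes its hash even though its repr is still LinearRegression(); that is why the message reads "from LinearRegression() to LinearRegression()". Fitting a clone instead leaves the original parameter untouched:

```python
import joblib
import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LinearRegression

# Toy data (assumption): y = 2x + 1
X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([1.0, 3.0, 5.0, 7.0])

# 1) Fitting the object that was passed in as a parameter mutates it.
lm = LinearRegression()
hash_before = joblib.hash(lm)
repr_before = repr(lm)
lm.fit(X, y)  # adds coef_, intercept_, ... to lm.__dict__
hash_after = joblib.hash(lm)
print(repr_before, "->", repr(lm))  # the repr does not show fitted attributes
print(hash_before == hash_after)    # False: the fitted attributes change the hash

# 2) Fitting a clone leaves the original parameter untouched.
lm2 = LinearRegression()
hash_orig = joblib.hash(lm2)
fitted = clone(lm2).fit(X, y)       # clone() returns an unfitted copy with the same params
print(joblib.hash(lm2) == hash_orig)                    # True
print(hasattr(fitted, "coef_"), hasattr(lm2, "coef_"))  # True False
```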

Tags: python scikit-learn


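【Solution 1】:

So, from looking at sklearn.compose.TransformedTargetRegressor (which does what I want), the key seems to be to use sklearn.base.clone to copy my model and fit the new copy, stored as self.model_ (with a trailing underscore to match the convention). The new code for my fit method becomes: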

from sklearn.base import clone

def fit(self, X, y):

    # Trying to reproduce as sklearn compatible estimator, we must use the
    # conventions:

    # The constructor in an estimator should only set attributes to the
    # values the user passes as arguments. All computation should occur in
    # fit, and if fit needs to store the result of a computation, it should
    # do so in an attribute with a trailing underscore (_). This convention
    # is what makes clone and meta-estimators such as GridSearchCV work.

    self.vectorized_transform_function_ =\
        np.vectorize(self.transform_function)
    self.vectorized_reverse_transform_function_ =\
        np.vectorize(self.reverse_transform_function)

    y_transformed = self.vectorized_transform_function_(y)
    # Fit an unfitted copy so the user-supplied `model` parameter is not mutated
    self.model_ = clone(self.model)
    self.model_.fit(X, y_transformed)

    return self

    # Note: predict must then call self.model_.predict(X) instead of
    # self.model.predict(X), because self.model itself is never fitted.
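Putting the pieces together, a minimal self-contained sketch of the whole class with the clone-based fit. It keeps the asker's np.vectorize approach and switches predict to the fitted clone; the toy data and identity lambdas are assumptions for the usage check:

```python
import numpy as np
from sklearn.base import BaseEstimator, clone
from sklearn.linear_model import LinearRegression


class TransformedSkModel(BaseEstimator):
    def __init__(self, model, transform_function, reverse_transform_function):
        # Only store the constructor arguments, unchanged.
        self.model = model
        self.transform_function = transform_function
        self.reverse_transform_function = reverse_transform_function

    def fit(self, X, y):
        self.vectorized_transform_function_ = np.vectorize(self.transform_function)
        self.vectorized_reverse_transform_function_ = np.vectorize(
            self.reverse_transform_function
        )
        y_transformed = self.vectorized_transform_function_(y)
        # Fit a fresh, unfitted copy so the `model` parameter is never mutated.
        self.model_ = clone(self.model)
        self.model_.fit(X, y_transformed)
        return self

    def predict(self, X):
        # Use the fitted clone, not the untouched `model` parameter.
        y_transformed = self.model_.predict(X)
        return self.vectorized_reverse_transform_function_(y_transformed)


# Usage check with toy data (assumption): y = 2x + 1, identity transforms.
X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([1.0, 3.0, 5.0, 7.0])

lm = LinearRegression()
est = TransformedSkModel(lm, lambda v: v, lambda v: v).fit(X, y)
y_pred = est.predict(X)
print(hasattr(lm, "coef_"), hasattr(est.model_, "coef_"))  # False True
```

After fit, get_params()["model"] still returns the same, unfitted `lm` object, which is what the parameter-mutation check expects.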

Now I get a different error related to the use of np.vectorize, but that's a separate problem, which I guess belongs in another question.
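The thread does not say which np.vectorize error appeared, so the following shows only one known np.vectorize pitfall, not necessarily the one hit here. Without `otypes`, np.vectorize infers the output dtype by calling the function on the first element, so it refuses size-0 inputs. Passing `otypes` explicitly avoids this; another option is to apply a NumPy-aware function directly to the whole array, which is what TransformedTargetRegressor's `func`/`inverse_func` do. The identity lambda is an assumption mirroring `id_func` from the question:

```python
import numpy as np

identity = lambda v: v  # stands in for the question's id_func
empty = np.array([], dtype=float)

# Without otypes, np.vectorize cannot infer the output dtype of a size-0 input.
try:
    np.vectorize(identity)(empty)
    raised = False
except ValueError as exc:
    raised = True
    print("ValueError:", exc)

# With an explicit output type, empty inputs are handled.
out = np.vectorize(identity, otypes=[float])(empty)
print(out.shape, out.dtype)
```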

【Discussion】:
