【问题标题】:Scikit-learn custom estimator "Invalid parameter for estimator" errorScikit-learn 自定义估计器“估计器的参数无效”错误
【发布时间】:2020-09-25 19:33:08
【问题描述】:

我正在为我的大学项目使用不同的投票方案为 SVM 集成实现自定义分类器。我的估算器代码:

svm_possible_args = {"C", "kernel", "degree", "gamma", "coef0", "shrinking", "probability", "tol", "cache_size",
                     "class_weight", "max_iter", "decision_function_shape", "break_ties"}

bagging_possible_args = {"n_estimators", "max_samples", "max_features", "bootstrap", "bootstrap_features",
                         "oob_score", "warm_start", "n_jobs"}

common_possible_args = {"random_state", "verbose"}


class SVMEnsemble(BaggingClassifier):
    def __init__(self, kernel="linear", voting_method=None, **kwargs):
        if voting_method not in {None, "hard", "soft"}:
            raise ValueError(f"voting_method {voting_method} is not recognized.")

        svm_args = dict()
        bagging_args = dict()
        for arg_name, arg_val in kwargs.items():
            if arg_name in svm_possible_args:
                svm_args[arg_name] = arg_val
            elif arg_name in bagging_possible_args:
                bagging_args[arg_name] = arg_val
            elif arg_name in common_possible_args:
                svm_args[arg_name] = arg_val
                bagging_args[arg_name] = arg_val
            else:
                raise ValueError(f"argument {voting_method} is not recognized.")

        probability = True if voting_method == "soft" else False
        svm_args = dict() if not svm_args else svm_args
        base_estimator = SVC(kernel=kernel, probability=probability, **svm_args)

        super().__init__(base_estimator=base_estimator, **bagging_args)
        self.voting_method = voting_method

    def predict(self, X):
        if self.voting_method in {None, "hard"}:
            return super().predict(X)
        elif self.voting_method == "soft":
            probabilities = np.zeros((X.shape[0], self.classes_.shape[0]))
            for estimator in self.estimators_:
                estimator_probabilities = estimator.predict_proba(X)
                probabilities += estimator_probabilities
            return self.classes_[probabilities.argmax(axis=1)]
        else:
            raise ValueError(f"voting_method {self.voting_method} is not recognized.")

我想从BaggingClassifier 继承大部分功能并插入SVC。用户应该能够指定 SVM 和 bagging 超参数,所以我使用了 for loop 和 svm_possible_args 等来过滤传递给 SVCBaggingClassifier 的参数。参数集几乎是可分离的(它们只有random_stateverbose 共同,这不是问题)。

我正在尝试使用GridSearchCV 找到最佳超参数:

def get_best_ensemble(X_train, y_train):
    parameters = {
        "voting_method": ["hard", "soft"],

        "max_samples": np.linspace(0.5, 1, 6, endpoint=True).round(1),
        "max_features": [0.7, 0.8, 0.9, 1],
        "n_estimators": [5, 10, 15],

        "kernel": ["linear", "poly", "rbf", "sigmoid"],
        "C": [0.01, 0.1, 0.5, 1, 10],
        "gamma": [0.01, 0.1, 0.3, 0.6, 1]
    }

    model = SVMEnsemble()
    grid = GridSearchCV(model, parameters, verbose=2, cv=5, n_jobs=-1)
    grid.fit(X_train, y_train)

    print("Best hyperparameters:")
    print(grid.best_params_)

    return grid.best_estimator_

我收到以下错误:

ValueError: Invalid parameter C for estimator SVMEnsemble(kernel=None, voting_method=None). Check the list of available parameters with `estimator.get_params().keys()`.

使用print(model.get_params().keys()) 我得到dict_keys(['kernel', 'voting_method'])。这是否意味着我必须在__init__SVMEnsembleGridSearchCV 中明确列出SVCBaggingClassifier所有 参数才能“看到”它们并实际工作?或者有什么更清洁的解决方案?

【问题讨论】:

  • 您是否必须自己实现 SVM 集成,或者您可以只使用 sklearn?您尝试做的事情似乎是多余的,因为您已经可以使用 BaggingClassifier 制作 SVM 集成,只需将 base_estimator 指定为 SVC。
  • @BlueSkyz 我不得不这样做,这是为了大学项目,它是强加给我的。此外,毕竟这样做对于 GridSearchCV 优化来说更容易。

标签: python scikit-learn


【解决方案1】:

您可以覆盖get_paramsset_params 方法,或者将实际的SVM 对象作为初始化参数。您需要做一些事情,以便当网格搜索尝试set_params 时,您实例中的estimator 会正确更新(不仅仅是实例中的参数;请注意__init__ 不会重新运行)。

有一些关于使继承类参数发现更容易的讨论,但它很棘手,并且无法解决第二个问题:
https://github.com/scikit-learn/scikit-learn/issues/13555

【讨论】:

  • 谢谢!这可能是做到这一点的“正确”方式。最后,我刚刚明确地编写了我必须使用 GridSearchCV 优化的那几个超参数,但总的来说,这些覆盖听起来很棒。
猜你喜欢
  • 2021-08-17
  • 2019-08-18
  • 2013-06-04
  • 2019-01-15
  • 2018-04-05
  • 1970-01-01
  • 1970-01-01
  • 2021-01-02
  • 2018-06-28
相关资源
最近更新 更多