【问题标题】:GridSearchCV: "TypeError: 'StratifiedKFold' object is not iterable"GridSearchCV:“TypeError:‘StratifiedKFold’对象不可迭代”
【发布时间】:2017-03-08 12:53:37
【问题描述】:

我想在一个 RandomForestClassifier 中执行 GridSearchCV,但是数据不平衡,所以我使用了 StratifiedKFold:

from sklearn.model_selection import StratifiedKFold
from sklearn.grid_search import GridSearchCV
from sklearn.ensemble import RandomForestClassifier

param_grid = {'n_estimators':[10, 30, 100, 300], "max_depth": [3, None],
          "max_features": [1, 5, 10], "min_samples_leaf": [1, 10, 25, 50], "criterion": ["gini", "entropy"]}

rfc = RandomForestClassifier()

clf = GridSearchCV(rfc, param_grid=param_grid, cv=StratifiedKFold()).fit(X_train, y_train)

但我得到一个错误:

TypeError                                 Traceback (most recent call last)
<ipython-input-597-b08e92c33165> in <module>()
     9 rfc = RandomForestClassifier()
     10 
---> 11 clf = GridSearchCV(rfc, param_grid=param_grid, cv=StratifiedKFold()).fit(X_train, y_train)

c:\python34\lib\site-packages\sklearn\grid_search.py in fit(self, X, y)
    811 
    812         """
--> 813         return self._fit(X, y, ParameterGrid(self.param_grid))

c:\python34\lib\site-packages\sklearn\grid_search.py in _fit(self, X, y, parameter_iterable)
    559                                     self.fit_params, return_parameters=True,
    560                                     error_score=self.error_score)
--> 561                 for parameters in parameter_iterable
    562                 for train, test in cv)

c:\python34\lib\site-packages\sklearn\externals\joblib\parallel.py in __call__(self, iterable)
    756             # was dispatched. In particular this covers the edge
    757             # case of Parallel used with an exhausted iterator.
--> 758             while self.dispatch_one_batch(iterator):
    759                 self._iterating = True
    760             else:

c:\python34\lib\site-packages\sklearn\externals\joblib\parallel.py in dispatch_one_batch(self, iterator)
    601 
    602         with self._lock:
--> 603             tasks = BatchedCalls(itertools.islice(iterator, batch_size))
    604             if len(tasks) == 0:
    605                 # No more tasks available in the iterator: tell caller to stop.

c:\python34\lib\site-packages\sklearn\externals\joblib\parallel.py in __init__(self, iterator_slice)
    125 
    126     def __init__(self, iterator_slice):
--> 127         self.items = list(iterator_slice)
    128         self._size = len(self.items)

c:\python34\lib\site-packages\sklearn\grid_search.py in <genexpr>(.0)
    560                                     error_score=self.error_score)
    561                 for parameters in parameter_iterable
--> 562                 for train, test in cv)
    563 
    564         # Out is a list of triplet: score, estimator, n_test_samples

TypeError: 'StratifiedKFold' object is not iterable

当我写 cv=StratifiedKFold(y_train) 时,我有 ValueError: The number of folds must be of Integral type. 但是当我写 `cv=5 时,它可以工作。

我不明白 StratifiedKFold 有什么问题

【问题讨论】:

    标签: pandas scikit-learn grid-search sklearn-pandas


    【解决方案1】:

    API 在最新版本中发生了变化。您曾经传递 y,而现在在创建 stratifiedKFold 对象时只传递数字。你稍后通过 y。

    【讨论】:

    • 我写cv=StratifiedKFold(10) 得到TypeError: 'StratifiedKFold' object is not iterable 我什么时候应该通过y?
    • 在当前版本中,您导入 sklearn.model_selection.StratifiedKFold。然后你可以做 cv=StratifiedKFold(10) 应该没有错误。但是,也许您是从以前的模块导入的,该模块在 20 版之前仍然存在以实现兼容性。
    • 我能再问一个问题吗?我从这个站点下载了lfd.uci.edu/~gohlke/pythonlibs/#scikit-learn 文件 scikit_learn-0.18-cp34-cp34m-win32.whl,安装了它,但现在我得到了ImportError: DLL load failed: %1 is not a valid Win32 application. 。怎么了?
    • 可能在某处缺少依赖项。最简单的方法是下载 anaconda。然后它就可以工作了。
    【解决方案2】:

    看来cv=StratifiedKFold()).fit(X_train, y_train)应该改成cv=StratifiedKFold()).split(X_train, y_train).

    【讨论】:

    • 这与错误无关。这一行: clf = GridSearchCV(rfc, param_grid=param_grid, cv=StratifiedKFold()).fit(X_train, y_train) 只是定义了对象 clf 然后它调用 fit 方法来训练/拟合 clf。
    • @rll 还提到 fit 应该替换为 split。
    【解决方案3】:

    这里的问题是其他答案中提到的 API 更改,但是答案可能更明确。

    cv 参数文档指出:

    cv : int,交叉验证生成器或可迭代的,可选的

    确定交叉验证拆分策略。可能的输入 简历是:

    • 无,使用默认的3折交叉验证,整数, 指定折叠数。

    • 一个对象被用作 交叉验证生成器。

    • 可迭代的产生训练/测试拆分。

    对于整数/无输入,如果 y 是二进制或多类,StratifiedKFold 用过的。如果估计器是分类器,或者 y 既不是二元的,也不是 多类,使用 KFold。

    所以,无论cross validation strategy 使用什么,只需要按照建议使用函数split 提供生成器:

    kfolds = StratifiedKFold(5)
    clf = GridSearchCV(estimator, parameters, scoring=qwk, cv=kfolds.split(xtrain,ytrain))
    clf.fit(xtrain, ytrain)
    

    【讨论】:

      【解决方案4】:

      我遇到了完全相同的问题。对我有用的解决方案是替换

      from sklearn.grid_search import GridSearchCV
      

      from sklearn.model_selection import GridSearchCV
      

      那么它应该可以正常工作了。

      【讨论】:

        猜你喜欢
        • 2013-09-01
        • 2017-08-27
        • 2018-10-10
        • 2021-12-13
        • 2019-02-20
        • 2020-03-27
        • 2018-12-12
        • 2018-07-16
        • 2011-09-12
        相关资源
        最近更新 更多