【问题标题】:Error while passing parameters for DecisionTreeClassifier为 DecisionTreeClassifier 传递参数时出错
【发布时间】:2018-02-04 12:49:44
【问题描述】:

我正在尝试使用字符串中的参数的 DecisionTreeClassifier。

 print d    # d= 'max_depth=100'
 clf = DecisionTreeClassifier(d)
 clf.fit(X[:3000,], labels[:3000])

在这种情况下,我遇到了错误。如果我使用clf = DecisionTreeClassifier(max_depth=100),它可以正常工作。

Traceback (most recent call last):
  File "train.py", line 120, in <module>
    grid_search_generalized(X, labels, {"max_depth":[i for i in range(100, 200)]})
  File "train.py", line 51, in grid_search_generalized
    clf.fit(X[:3000,], labels[:3000])
  File "/usr/local/lib/python2.7/dist-packages/sklearn/tree/tree.py", line 790, in fit
    X_idx_sorted=X_idx_sorted)
  File "/usr/local/lib/python2.7/dist-packages/sklearn/tree/tree.py", line 326, in fit
    criterion = CRITERIA_CLF[self.criterion](self.n_outputs_,
KeyError: 'max_depth=100'

【问题讨论】:

    标签: python scikit-learn decision-tree


    【解决方案1】:

    没有在 DecisionTreeClassifier 函数中定义关键字变量参数。 max_depth 可以作为关键字参数传递。请尝试以下代码:

    d= 'max_depth=100'
    arg = dict([d.split("=")])
    i = int(next(iter(arg.values())))
    k = next(iter(arg.keys()))
    clf = DecisionTreeClassifier(max_depth=args['max_depth'])
    clf.fit(X[:3000,], labels[:3000])
    

    输出:

    DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=100,
                           max_features=None, max_leaf_nodes=None,
                           min_impurity_decrease=0.0, min_impurity_split=None,
                           min_samples_leaf=1, min_samples_split=2,
                           min_weight_fraction_leaf=0.0, presort=False,
                           random_state=None, splitter='best')
    

    【讨论】:

      【解决方案2】:

      您将参数作为字符串对象而不是作为可选参数传递。
      如果你真的不得不用这个字符串调用构造函数,你可以使用这个代码:

       arg = dict([d.split("=")])
       clf = DecisionTreeClassifier(**arg)
      

      您可以在此链接中阅读有关解包参数的更多信息
      Passing a dictionary to a function in python as keyword parameters

      【讨论】:

      • Traceback (most recent call last): File "train.py", line 121, in &lt;module&gt; grid_search_generalized(X, labels, {"max_depth":[100]}) File "train.py", line 52, in grid_search_generalized clf.fit(X[:3000,], labels[:3000]) File "/usr/local/lib/python2.7/dist-packages/sklearn/tree/tree.py", line 790, in fit X_idx_sorted=X_idx_sorted) File "/usr/local/lib/python2.7/dist-packages/sklearn/tree/tree.py", line 326, in fit criterion = CRITERIA_CLF[self.criterion](self.n_outputs_, TypeError: unhashable type: 'dict' 现在出现此错误。
      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2016-10-08
      • 2014-06-21
      • 1970-01-01
      • 1970-01-01
      • 2021-02-10
      • 1970-01-01
      • 2013-06-12
      相关资源
      最近更新 更多