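【Posted】: 2021-10-30 16:31:11
【Problem description】:
I am trying to build a decision tree using GridSearchCV and a pipeline, but I get an error when I try to export the image with graphviz. I searched online and could not find anything; one potential cause would be not using the best_estimator_ instance, but I am using it here.
Everything works (getting the accuracy and other metrics) except the part that exports the graph.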
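from sklearn import tree
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.tree import export_graphviz

def TreeOpt(X, y):
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

    # Pipeline: standardize the features, then fit a decision tree
    std_scl = StandardScaler()
    dec_tree = tree.DecisionTreeClassifier()
    pipe = Pipeline(steps=[('std_slc', std_scl),
                           ('dec_tree', dec_tree)])

    # Grid over the tree's split criterion and maximum depth
    criterion = ['gini', 'entropy']
    max_depth = list(range(1, 15))
    parameters = dict(dec_tree__criterion=criterion,
                      dec_tree__max_depth=max_depth)

    tree_gs = GridSearchCV(pipe, parameters)
    tree_gs.fit(X_train, y_train)

    # best_estimator_ is the fitted Pipeline; this is the call that fails
    export_graphviz(
        tree_gs.best_estimator_,
        out_file="dec_tree.dot",
        feature_names=None,
        class_names=None,
        filled=True)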
But I get:
<ipython-input-2-bb91ec6ba0d9> in <module>
37 filled=True)
38
---> 39 DecTreeOptimizer(X = df.drop(['quality'], axis=1), y = df.quality)
40
<ipython-input-2-bb91ec6ba0d9> in DecTreeOptimizer(X, y)
30 print("Best score: " + str(tree_GS.best_score_))
31
---> 32 export_graphviz(
33 tree_GS.best_estimator_,
34 out_file=("dec_tree.dot"),
~\AppData\Local\Programs\Python\Python39\lib\site-packages\sklearn\utils\validation.py in inner_f(*args, **kwargs)
61 extra_args = len(args) - len(all_args)
62 if extra_args <= 0:
---> 63 return f(*args, **kwargs)
64
65 # extra_args > 0
~\AppData\Local\Programs\Python\Python39\lib\site-packages\sklearn\tree\_export.py in export_graphviz(decision_tree, out_file, max_depth, feature_names, class_names, label, filled, leaves_parallel, impurity, node_ids, proportion, rotate, rounded, special_characters, precision)
767 """
768
--> 769 check_is_fitted(decision_tree)
770 own_file = False
771 return_string = False
~\AppData\Local\Programs\Python\Python39\lib\site-packages\sklearn\utils\validation.py in inner_f(*args, **kwargs)
61 extra_args = len(args) - len(all_args)
62 if extra_args <= 0:
---> 63 return f(*args, **kwargs)
64
65 # extra_args > 0
~\AppData\Local\Programs\Python\Python39\lib\site-packages\sklearn\utils\validation.py in check_is_fitted(estimator, attributes, msg, all_or_any)
1096
1097 if not attrs:
-> 1098 raise NotFittedError(msg % {'name': type(estimator).__name__})
1099
1100
NotFittedError: This Pipeline instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.
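The traceback shows `check_is_fitted(decision_tree)` being called on a `Pipeline`: `tree_gs.best_estimator_` is the whole fitted pipeline, while `export_graphviz` expects the decision tree estimator itself. A minimal sketch of one likely fix is to pull the fitted tree step out of the pipeline by name before exporting. The iris data, the smaller `max_depth` grid, and `out_file=None` are assumptions made here to keep the example short and self-contained; they are not part of the original post.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Assumption: iris stands in for the wine-quality DataFrame used in the question.
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=1)

pipe = Pipeline(steps=[("std_slc", StandardScaler()),
                       ("dec_tree", DecisionTreeClassifier(random_state=1))])

# Smaller grid than in the question, only to keep the example fast.
parameters = {"dec_tree__criterion": ["gini", "entropy"],
              "dec_tree__max_depth": list(range(1, 5))}

tree_gs = GridSearchCV(pipe, parameters)
tree_gs.fit(X_train, y_train)

# best_estimator_ is the fitted Pipeline; select the tree step by its name.
best_tree = tree_gs.best_estimator_.named_steps["dec_tree"]

# out_file=None returns the DOT source as a string instead of writing a file.
dot_source = export_graphviz(best_tree, out_file=None, filled=True)
```

To write the file as in the question, pass `out_file="dec_tree.dot"` instead. Because the tree was fitted on the output of `StandardScaler`, the split thresholds in the exported graph are in standardized units, not the original feature units.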
【Discussion】:
-
Please edit your question to include the full error trace.
-
@desertnaut, I have added it now.
Tags: machine-learning scikit-learn graphviz decision-tree