【Posted】: 2017-08-30 01:46:33
【Question】:
Extending a previous question: Changing colors for decision tree plot created using export graphviz
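How can I color the tree's nodes by their majority class (the iris species) rather than by the binary left/right-child distinction? This should require combining iris.target_names (the strings naming the classes) with iris.target (the class labels).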
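A minimal sketch of how the two can be combined, assuming a fitted DecisionTreeClassifier: the per-node class distribution lives in clf.tree_.value (raw counts or normalized fractions depending on the scikit-learn version; argmax picks the same class either way), its columns follow clf.classes_, and for iris those labels are the integers 0, 1, 2 that index directly into iris.target_names. The names majority_idx and majority_name are illustrative only.

```python
import numpy as np
from sklearn import tree
from sklearn.datasets import load_iris

iris = load_iris()
clf = tree.DecisionTreeClassifier(random_state=42).fit(iris.data, iris.target)

# tree_.value has shape (n_nodes, n_outputs, n_classes); with a single output,
# argmax over the class axis gives each node's majority class (column index).
majority_idx = np.argmax(clf.tree_.value[:, 0, :], axis=1)

# Column j corresponds to clf.classes_[j]; for iris those labels are 0, 1, 2,
# which are valid indices into iris.target_names.
majority_name = iris.target_names[clf.classes_[majority_idx]]

for node_id, name in enumerate(majority_name):
    print(node_id, name)
```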
My current code, which colors the two children of every split brown and forestgreen:

import collections

import pydotplus
from sklearn import tree
from sklearn.datasets import load_iris

clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()
clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)

colors = ('brown', 'forestgreen')

# Map each parent node id to the ids of its two children.
edges = collections.defaultdict(list)
for edge in graph.get_edge_list():
    edges[edge.get_source()].append(int(edge.get_destination()))

# The lower-numbered child is the left branch: color it brown, the right one green.
for edge in edges:
    edges[edge].sort()
    for i in range(2):
        dest = graph.get_node(str(edges[edge][i]))[0]
        dest.set_fillcolor(colors[i])

graph.write_png('tree.png')
【Discussion】:
Tags: python-3.x scikit-learn decision-tree graph-visualization pydot