【问题标题】:Color of the node of tree with graphviz using class_names使用 class_names 使用 graphviz 的树节点的颜色
【发布时间】:2017-08-30 01:46:33
【问题描述】:

扩展上一个问题: Changing colors for decision tree plot created using export graphviz

如何根据主要类(虹膜的种类)而不是二元区分为树的节点着色?这应该需要 iris.target_names(描述​​类的字符串)和 iris.target(类)的组合。

import pydotplus
from sklearn.datasets import load_iris
from sklearn import tree
import collections

clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()

clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
nodes = graph.get_node_list()
edges = graph.get_edge_list()

colors = ('brown', 'forestgreen')
edges = collections.defaultdict(list)

for edge in graph.get_edge_list():
    edges[edge.get_source()].append(int(edge.get_destination()))

for edge in edges:
    edges[edge].sort()    
    for i in range(2):
        dest = graph.get_node(str(edges[edge][i]))[0]
        dest.set_fillcolor(colors[i])

graph.write_png('tree.png')

【问题讨论】:

    标签: python-3.x scikit-learn decision-tree graph-visualization pydot


    【解决方案1】:

    示例中的代码看起来很熟悉,因此很容易修改:)

    对于每个节点 Graphviz 告诉我们每个组有多少样本,即它是混合种群还是决策树。我们可以提取此信息并用于获取颜色。

    values = [int(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')]
    

    或者,您可以将GraphViz 节点映射回sklearn 节点:

    values = clf.tree_.value[int(node.get_name())][0]
    

    我们只有 3 个类别,因此每个类别都有自己的颜色(红、绿、蓝),混合人群根据他们的分布获得混合颜色。

    values = [int(255 * v / sum(values)) for v in values]
    color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2])
    

    我们现在可以很好地看到分离,它越绿,我们拥有的第二类就越多,蓝色和第三类也是如此。


    import pydotplus
    from sklearn.datasets import load_iris
    from sklearn import tree
    
    clf = tree.DecisionTreeClassifier(random_state=42)
    iris = load_iris()
    
    clf = clf.fit(iris.data, iris.target)
    
    dot_data = tree.export_graphviz(clf,
                                    feature_names=iris.feature_names,
                                    out_file=None,
                                    filled=True,
                                    rounded=True,
                                    special_characters=True)
    graph = pydotplus.graph_from_dot_data(dot_data)
    nodes = graph.get_node_list()
    
    for node in nodes:
        if node.get_label():
            values = [int(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')]
            values = [int(255 * v / sum(values)) for v in values]
            color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2])
            node.set_fillcolor(color)
    
    graph.write_png('colored_tree.png')
    

    3 个以上类的通用解决方案,仅对最终节点着色。

    colors =  ('lightblue', 'lightyellow', 'forestgreen', 'lightred', 'white')
    
    for node in nodes:
        if node.get_name() not in ('node', 'edge'):
            values = clf.tree_.value[int(node.get_name())][0]
            #color only nodes where only one class is present
            if max(values) == sum(values):    
                node.set_fillcolor(colors[numpy.argmax(values)])
            #mixed nodes get the default color
            else:
                node.set_fillcolor(colors[-1])
    

    【讨论】:

    • 我的个人问题有四个类。对于 n 个类,你如何概括这个?
    • 我原本打算使用:colors = ('lightblue', 'lightyellow', 'lightgreen', 'lightred',) 然后在它们之间进行插值。
    • @MyopicVisage:这很有挑战性 :) 就我个人而言,我宁愿只对最后的节点进行着色,否则它会变成一棵圣诞树。我会再考虑一下。
    • 我同意。我的问题陈述旨在阅读,用支配类为节点着色,因此不需要插值。
    • @MyopicVisage:查看更新的答案以获得一般解决方案
    【解决方案2】:

    很好的答案伙计们。只是为了添加到@Maximilian Peters 的答案。可以为特定颜色识别叶节点的另一件事是检查 split_criteria(threshold) 值。由于叶节点没有子节点,因此也没有拆分标准。

    https://github.com/scikit-learn/scikit-learn/blob/a24c8b464d094d2c468a16ea9f8bf8d42d949f84/sklearn/tree/_tree.pyx
    TREE_UNDEFINED = -2 
    thresholds = clf.tree_.threshold
    for node in nodes:
        if node.get_name() not in ('node', 'edge'):
            value = clf.tree_.value[int(node.get_name())][0]
            # color only nodes where only one class is present or if it is a leaf 
            # node
            if max(values) == sum(values) or 
                thresholds[int(node.get_name())] == TREE_UNDEFINED:    
                    node.set_fillcolor(colors[numpy.argmax(value)])
            # mixed nodes get the default color
            else:
                node.set_fillcolor(colors[-1])
    

    与问题不完全相关,但添加更多信息以防对其他人有帮助。 继续理解基于树的分类器的决策树桩的想法,Skater 增加了对使用树代理来总结所有形式的基于树的模型的支持。在此处查看示例。

    https://github.com/datascienceinc/Skater/blob/master/examples/rule_list_notebooks/explanation_using_tree_surrogate.ipynb

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2021-12-08
      • 2020-04-10
      • 1970-01-01
      • 1970-01-01
      • 2015-10-28
      • 2016-12-17
      • 2017-08-10
      相关资源
      最近更新 更多