【Question title】: Plot decision paths for all of the test samples
【Published】: 2021-04-12 22:10:55
【Question description】:

I have this code, taken from:

How to display the path of a Decision Tree for test samples?

Basically, it draws a sample's decision path on the decision-tree plot, so you can see how a particular prediction was made.
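For reference, here is a minimal sketch (assuming the same iris model as in the code below) of what clf.decision_path returns: a SciPy sparse matrix with one row per sample and one column per tree node, where a non-zero entry means the sample passes through that node. clf.apply gives the leaf each sample ends up in. Because each row is one sample's path, a loop like `for decision_path in decision_paths` visits one sample per iteration.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=42).fit(iris.data, iris.target)

samples = iris.data[120:130]        # 10 samples -> 10 rows
paths = clf.decision_path(samples)  # sparse indicator matrix, shape (n_samples, n_nodes)
leaves = clf.apply(samples)         # id of the leaf each sample ends up in

for i in range(paths.shape[0]):
    # The non-zero columns of row i are the ids of the nodes sample i passes through
    node_ids = paths.indices[paths.indptr[i]:paths.indptr[i + 1]]
    print(f"sample {i}: nodes {node_ids.tolist()} -> leaf {leaves[i]}")
```

Reading the CSR indptr/indices arrays directly avoids densifying the whole matrix with toarray(), which matters little for iris but keeps the per-sample work proportional to the path length.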

import pydotplus
from sklearn.datasets import load_iris
from sklearn import tree

clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()

clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)

# empty all nodes, i.e. set the color to white and the number of samples to zero
for node in graph.get_node_list():
    if node.get_attributes().get('label') is None:
        continue
    if 'samples = ' in node.get_attributes()['label']:
        labels = node.get_attributes()['label'].split('<br/>')
        for i, label in enumerate(labels):
            if label.startswith('samples = '):
                labels[i] = 'samples = 0'
        node.set('label', '<br/>'.join(labels))
        node.set_fillcolor('white')

samples = iris.data[129:130]
decision_paths = clf.decision_path(samples)

for decision_path in decision_paths:
    for n, node_value in enumerate(decision_path.toarray()[0]):
        if node_value == 0:
            continue
        node = graph.get_node(str(n))[0]            
        node.set_fillcolor('green')
        labels = node.get_attributes()['label'].split('<br/>')
        for i, label in enumerate(labels):
            if label.startswith('samples = '):
                labels[i] = 'samples = {}'.format(int(label.split('=')[1]) + 1)

        node.set('label', '<br/>'.join(labels))

filename = 'tree.png'
graph.write_png(filename)

What I want to do is plot the decision paths of all the samples, each in a separate figure, in my Jupyter notebook. What should I add to the code?

【Question discussion】:

    Tags: graphviz decision-tree


    【Solution 1】:

    Since the question does not state the expected result, I assume you want to plot the decision paths of the classified samples.

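    One solution is to classify several samples at once and add another level of for loop. All of the paths are then highlighted on the same tree, and each node's samples count is incremented once for every sample that passes through it. Because every node is checked for every sample, the run time grows with the number of samples, so use it with care.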

    import pydotplus
    from IPython.display import Image, display_png  # needed for the inline display below
    from sklearn.datasets import load_iris
    from sklearn import tree
    
    clf = tree.DecisionTreeClassifier(random_state=42)
    iris = load_iris()
    
    clf = clf.fit(iris.data, iris.target)
    
    dot_data = tree.export_graphviz(clf, out_file=None,
                                    feature_names=iris.feature_names,
                                    class_names=iris.target_names,
                                    filled=True, rounded=True,
                                    special_characters=True)
    graph = pydotplus.graph_from_dot_data(dot_data)
    
    # empty all nodes, i.e. set the color to white and the number of samples to zero
    for node in graph.get_node_list():
        if node.get_attributes().get('label') is None:
            continue
        if 'samples = ' in node.get_attributes()['label']:
            labels = node.get_attributes()['label'].split('<br/>')
            for i, label in enumerate(labels):
                if label.startswith('samples = '):
                    labels[i] = 'samples = 0'
            node.set('label', '<br/>'.join(labels))
            node.set_fillcolor('white')
    
    # samples = iris.data[129:130]
    samples = iris.data[120:130] # <-- classify 10 samples instead of 1
    decision_paths = clf.decision_path(samples)
    
    for decision_path in decision_paths:
        for path in decision_path.toarray(): # <-- Adding one more layer of for loop to loop each path
        # for n, node_value in enumerate(decision_path.toarray()[0]):
            for n, node_value in enumerate(path):
                if node_value == 0:
                    continue
                node = graph.get_node(str(n))[0]
                node.set_fillcolor('green')
                labels = node.get_attributes()['label'].split('<br/>')
                for i, label in enumerate(labels):
                    if label.startswith('samples = '):
                        labels[i] = 'samples = {}'.format(int(label.split('=')[1]) + 1)
    
                node.set('label', '<br/>'.join(labels))
    
    # display it inline
    display_png(Image(graph.create_png()))
    
    # or save it as png
    filename = 'tree.png'
    graph.write_png(filename)
    

    Result: [image of the rendered tree not preserved]
