【问题标题】:How to display the path of a Decision Tree for test samples?如何显示测试样本的决策树路径?
【发布时间】:2019-09-16 14:17:48
【问题描述】:

我正在使用来自 scikit-learn 的DecisionTreeClassifier 对一些多类数据进行分类。我发现了许多描述如何显示决策树路径的帖子,例如 hereherehere。但是,它们都描述了如何为训练数据显示树。这是有道理的,因为export_graphviz 只需要一个拟合模型。

我的问题是如何可视化测试样本上的树(最好是export_graphviz)。 IE。在用clf.fit(X[train], y[train]) 拟合模型,然后通过clf.predict(X[test]) 预测测试数据的结果后,我想可视化用于预测样本X[test] 的决策路径。有没有办法做到这一点?

编辑:

我看到可以使用decision_path 打印路径。如果有办法让 DOT 输出到 export_graphviz 来显示它,那就太好了。

【问题讨论】:

    标签: python scikit-learn visualization decision-tree


    【解决方案1】:

    为了获取决策树中特定样本的路径,您可以使用decision_path。它返回一个稀疏矩阵,其中包含所提供样本的决策路径。

    然后可以使用这些决策路径对通过pydot 生成的树进行着色/标记。这需要覆盖颜色和标签(这会导致代码有点难看)。

    备注

    • decision_path 可以从训练集或新值中抽取样本
    • 您可以随意使用颜色,并根据样本数量或可能需要的任何其他可视化更改颜色

    示例

    在下面的示例中,访问的节点以绿色着色,所有其他节点均为白色。

    import pydotplus
    from sklearn.datasets import load_iris
    from sklearn import tree
    
    clf = tree.DecisionTreeClassifier(random_state=42)
    iris = load_iris()
    
    clf = clf.fit(iris.data, iris.target)
    
    dot_data = tree.export_graphviz(clf, out_file=None,
                                    feature_names=iris.feature_names,
                                    class_names=iris.target_names,
                                    filled=True, rounded=True,
                                    special_characters=True)
    graph = pydotplus.graph_from_dot_data(dot_data)
    
    # empty all nodes, i.e.set color to white and number of samples to zero
    for node in graph.get_node_list():
        if node.get_attributes().get('label') is None:
            continue
        if 'samples = ' in node.get_attributes()['label']:
            labels = node.get_attributes()['label'].split('<br/>')
            for i, label in enumerate(labels):
                if label.startswith('samples = '):
                    labels[i] = 'samples = 0'
            node.set('label', '<br/>'.join(labels))
            node.set_fillcolor('white')
    
    samples = iris.data[129:130]
    decision_paths = clf.decision_path(samples)
    
    for decision_path in decision_paths:
        for n, node_value in enumerate(decision_path.toarray()[0]):
            if node_value == 0:
                continue
            node = graph.get_node(str(n))[0]            
            node.set_fillcolor('green')
            labels = node.get_attributes()['label'].split('<br/>')
            for i, label in enumerate(labels):
                if label.startswith('samples = '):
                    labels[i] = 'samples = {}'.format(int(label.split('=')[1]) + 1)
    
            node.set('label', '<br/>'.join(labels))
    
    filename = 'tree.png'
    graph.write_png(filename)
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2019-12-26
      • 2019-06-01
      • 2021-11-29
      • 2017-04-27
      • 2015-12-29
      • 2019-01-07
      • 2016-08-10
      • 2017-09-17
      相关资源
      最近更新 更多