【Question title】: How to retrieve the full branch path leading to each leaf node of a sklearn Decision Tree?
【Posted】: 2021-05-23 15:00:45
【Question description】:

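I have this decision tree, and I want to extract every branch from it. The image shows only part of the tree, because the original tree is much larger and does not fit well in a single image.

I don't want to print the rules of the tree like this: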

Rules used to predict sample 1400:

decision node 0 : (X[1400, 4] = 92.85714285714286) > 96.42856979370117)
decision node 4 : (X[1400, 3] = 45.03259584336583) > 53.49640464782715)

or like this:

The binary tree structure has 7 nodes and has the following tree structure:

node=0 is a split node: go to node 1 if 4 <= 96.42856979370117 else to node 4.
    node=1 is a split node: go to node 2 if 3 <= 96.42856979370117 else to node 3.
    node=4 is a split node: go to node 5 if 5 <= 0.28278614580631256 else to node 6.

What I want to achieve is:

branch 0: x[4] <= 96.429,x[3]<=96.429,class=B,gini_score=0.5
branch 1: x[4] <= 96.429,x[3]>96.429,class=B,gini_score=0.021
branch 2: x[4] > 96.429,x[5]<=0.283,class=A,gini_score=0.092
branch 4: x[4] > 96.429,x[5]>0.283,class=A,gini_score=0.01

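Basically, I am trying to get every branch from the root down to a leaf node (the full path), together with its class and Gini score. How can I do this?

【Comments】:

    Tags: python python-3.x scikit-learn decision-tree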


    【Solution 1】:

    Taking the iris dataset example from the sklearn docs, we follow the steps below.

    1. Generate an example decision tree

    The code is taken from the docs.

    from matplotlib import pyplot as plt
    from sklearn.model_selection import train_test_split
    from sklearn.datasets import load_iris
    from sklearn.tree import DecisionTreeClassifier
    from sklearn import tree
    import numpy as np
    
    iris = load_iris()
    X = iris.data
    y = iris.target
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
    
    clf = DecisionTreeClassifier(max_leaf_nodes=6, random_state=0)
    clf.fit(X_train, y_train)
    

    2. Retrieve the branch paths

    First, we retrieve the following values from the tree:

    n_nodes = clf.tree_.node_count
    children_left = clf.tree_.children_left
    children_right = clf.tree_.children_right
    feature = clf.tree_.feature
    threshold = clf.tree_.threshold
    impurity = clf.tree_.impurity
    value = clf.tree_.value
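
To make the meaning of these arrays concrete, here is a minimal sketch of how leaf nodes can be told apart from split nodes. scikit-learn stores -1 in both child arrays for a node without children. The two lists are hard-coded so the snippet runs without a fitted model; they are an assumption, reconstructed from the branch paths printed further down in this answer, and the helper name `leaf_nodes` is hypothetical.

```python
# Child arrays reconstructed from the paths printed later in this answer
# (an assumption; with a fitted tree, use clf.tree_.children_left / children_right).
children_left = [1, -1, 3, 5, 9, -1, 7, -1, -1, -1, -1]
children_right = [2, -1, 4, 6, 10, -1, 8, -1, -1, -1, -1]

TREE_LEAF = -1  # sentinel scikit-learn stores for "no child"


def leaf_nodes(left, right):
    """Return the indices of the nodes that have no children."""
    return [
        node
        for node, (l, r) in enumerate(zip(left, right))
        if l == TREE_LEAF and r == TREE_LEAF
    ]


leaves = leaf_nodes(children_left, children_right)
print(leaves)
```

Every leaf found this way is the end point of exactly one branch, so the number of leaves equals the number of paths to retrieve.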
    

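    Inside retrieve_branches, we work out which nodes are leaves and iterate from the root node down to the leaves; whenever we reach a leaf node, we return its branch path with a yield statement.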

    def retrieve_branches(number_nodes, children_left_list, children_right_list):
        """Yield every root-to-leaf path of the tree as a list of node ids."""

        # A node is a leaf when both children are equal (both are -1 in sklearn)
        is_leaves_list = [cl == cr for cl, cr in zip(children_left_list, children_right_list)]

        # Partial paths that have not reached a leaf yet; start at the root
        paths = [[0]]

        # A parent always has a smaller index than its children, so by the time
        # node i is visited, the partial path ending at i already exists
        for i in range(number_nodes):
            end_nodes = [path[-1] for path in paths]
            if i not in end_nodes:
                continue
            index = end_nodes.index(i)

            if is_leaves_list[i]:
                # Leaf reached: the path is complete
                yield paths.pop(index)
            else:
                # Split node: extend the path once for each child
                path = paths[index]
                end_l, end_r = int(children_left_list[i]), int(children_right_list[i])
                paths[index] = path + [end_l]
                paths.append(path + [end_r])
    

    To call retrieve_branches, just pass n_nodes, children_left, and children_right. The final result is shown below.

    all_branches = list(retrieve_branches(n_nodes, children_left, children_right))
    all_branches
    >>> 
    [[0, 1],
     [0, 2, 3, 5],
     [0, 2, 3, 6, 7],
     [0, 2, 3, 6, 8],
     [0, 2, 4, 9],
     [0, 2, 4, 10]]
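
For comparison, the same paths can also be collected with a short recursive depth-first walk. This is only a sketch of an alternative, not the approach used in the answer: `iter_branches` is a hypothetical name, and the child arrays are hard-coded (reconstructed from the output above) so the snippet runs without a fitted model.

```python
def iter_branches(left, right, node=0, prefix=()):
    """Yield every root-to-leaf path as a list of node indices (depth-first)."""
    path = prefix + (node,)
    if left[node] == right[node]:  # both entries are -1 at a leaf
        yield list(path)
        return
    # Visit the left subtree first, then the right one
    yield from iter_branches(left, right, left[node], path)
    yield from iter_branches(left, right, right[node], path)


# Assumed arrays, reconstructed from the paths shown above;
# with a fitted tree, pass clf.tree_.children_left / children_right instead.
children_left = [1, -1, 3, 5, 9, -1, 7, -1, -1, -1, -1]
children_right = [2, -1, 4, 6, 10, -1, 8, -1, -1, -1, -1]

branches = list(iter_branches(children_left, children_right))
for branch in branches:
    print(branch)
```

The recursion depth equals the depth of the tree, so for very deep trees an explicit stack may be the safer choice.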
    

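    3. Path, value, and Gini per branch

    The rules can be obtained from the feature and threshold values of clf.tree_, together with the impurity (clf.tree_.impurity) and the value (clf.tree_.value) at the leaf node.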

    for index, branch in enumerate(all_branches):
        leaf_index = branch[-1]
        print(f'Branch: {index}, Path: {branch}')
        print(f'Gini {impurity[leaf_index]} at leaf node {branch[-1]}')
        print(f'Value: {value[leaf_index]}')
        print(f"Decision Rules: {[f'if X[:, {feature[elem]}] <= {threshold[elem]}' for elem in branch]}")
        print(f"---------------------------------------------------------------------------------------\n")
    >>>
    Branch: 0, Path: [0, 1]
    Gini 0.0 at leaf node 1
    Value: [[37.  0.  0.]]
    Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, -2] <= -2.0']
    ---------------------------------------------------------------------------------------

    Branch: 1, Path: [0, 2, 3, 5]
    Gini 0.0 at leaf node 5
    Value: [[ 0. 32.  0.]]
    Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 3] <= 1.6500000357627869', 'if X[:, -2] <= -2.0']
    ---------------------------------------------------------------------------------------

    Branch: 2, Path: [0, 2, 3, 6, 7]
    Gini 0.0 at leaf node 7
    Value: [[0. 0. 3.]]
    Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 3] <= 1.6500000357627869', 'if X[:, 1] <= 3.100000023841858', 'if X[:, -2] <= -2.0']
    ---------------------------------------------------------------------------------------

    Branch: 3, Path: [0, 2, 3, 6, 8]
    Gini 0.0 at leaf node 8
    Value: [[0. 1. 0.]]
    Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 3] <= 1.6500000357627869', 'if X[:, 1] <= 3.100000023841858', 'if X[:, -2] <= -2.0']
    ---------------------------------------------------------------------------------------

    Branch: 4, Path: [0, 2, 4, 9]
    Gini 0.375 at leaf node 9
    Value: [[0. 1. 3.]]
    Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 2] <= 5.049999952316284', 'if X[:, -2] <= -2.0']
    ---------------------------------------------------------------------------------------

    Branch: 5, Path: [0, 2, 4, 10]
    Gini 0.0 at leaf node 10
    Value: [[ 0.  0. 35.]]
    Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 2] <= 5.049999952316284', 'if X[:, -2] <= -2.0']
    ---------------------------------------------------------------------------------------
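
The loop above prints `<=` for every node on the path and also lists the leaf itself, whose feature and threshold are the placeholder -2. One possible way to get closer to the format asked for in the question is to pick the sign from the direction taken at each split: `<=` when the next node on the path is the left child, `>` otherwise. The sketch below is an illustration under assumptions, not part of the original answer: the helper names are hypothetical, the arrays are hard-coded from the output printed above, and the class is reported as the index of the largest entry in the leaf's value.

```python
def branch_rule(path, left, feature, threshold):
    """Return the conditions along one root-to-leaf path, with the right sign."""
    conditions = []
    for parent, child in zip(path[:-1], path[1:]):
        # Moving to the left child means the test "feature <= threshold" held
        sign = "<=" if left[parent] == child else ">"
        conditions.append(f"x[{feature[parent]}] {sign} {threshold[parent]:.3f}")
    return conditions


def describe(index, path, left, feature, threshold, impurity, value):
    """Format one branch as: conditions, predicted class index, Gini at the leaf."""
    leaf = path[-1]
    counts = list(value[leaf])
    predicted = counts.index(max(counts))  # position of the majority class
    rule = ", ".join(branch_rule(path, left, feature, threshold))
    return f"branch {index}: {rule}, class={predicted}, gini_score={impurity[leaf]:.3f}"


# Assumed data, copied from the output above (thresholds rounded). With a fitted
# tree, use clf.tree_.children_left, .feature, .threshold, .impurity and
# clf.tree_.value[:, 0] (single-output case) instead.
children_left = [1, -1, 3, 5, 9, -1, 7, -1, -1, -1, -1]
feature = [3, -2, 2, 3, 2, -2, 1, -2, -2, -2, -2]
threshold = [0.8, -2.0, 4.95, 1.65, 5.05, -2.0, 3.1, -2.0, -2.0, -2.0, -2.0]
leaf_impurity = {1: 0.0, 5: 0.0, 7: 0.0, 8: 0.0, 9: 0.375, 10: 0.0}
leaf_value = {1: [37, 0, 0], 5: [0, 32, 0], 7: [0, 0, 3],
              8: [0, 1, 0], 9: [0, 1, 3], 10: [0, 0, 35]}
all_branches = [[0, 1], [0, 2, 3, 5], [0, 2, 3, 6, 7],
                [0, 2, 3, 6, 8], [0, 2, 4, 9], [0, 2, 4, 10]]

lines = [
    describe(i, path, children_left, feature, threshold, leaf_impurity, leaf_value)
    for i, path in enumerate(all_branches)
]
for line in lines:
    print(line)
```

To show class names instead of indices, the predicted index could be looked up in `clf.classes_` of the fitted classifier.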
    
    

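    【Discussion】:

    • Thanks for your answer, Miguel. It gives me TypeError: 'int' object is not iterable when running retrieve_branches. I also noticed that the argument order differs from the order defined in the function, and that you added an is_leaves argument; I don't know where it comes from. And does the variable path inside the function correspond to the path_list argument?
    • I also noticed that when getting the rules for all the branches (in step 3), the inequality is always '
    • @xerac Sorry, the function had some errors; please check the edit. The sign of the inequality can change depending on the threshold.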