考虑到来自 sklearn docs 的 iris 数据集示例,我们遵循以下步骤。
1.生成示例决策树
代码取自docs
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
import numpy as np

# Load the example dataset shipped with scikit-learn.
iris = load_iris()
X, y = iris.data, iris.target

# Split into train/test with a fixed seed so the tree is reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Cap the tree at 6 leaves so the branch listing stays small and readable.
clf = DecisionTreeClassifier(max_leaf_nodes=6, random_state=0)
clf.fit(X_train, y_train)
2. 检索分支路径
首先我们从树中检索以下值
# Pull the raw structure arrays out of the fitted tree. Each array is
# indexed by node id; children ids are -1 for leaves.
_tree_struct = clf.tree_
n_nodes = _tree_struct.node_count
children_left = _tree_struct.children_left
children_right = _tree_struct.children_right
feature = _tree_struct.feature
threshold = _tree_struct.threshold
impurity = _tree_struct.impurity
value = _tree_struct.value
在retrieve_branches 内部,我们计算叶节点并从原始节点向下迭代到叶节点,当我们到达叶节点时,我们使用yield 语句返回分支路径。
def retrieve_branches(number_nodes, children_left_list, children_right_list):
    """Yield every root-to-leaf path of a decision tree as a list of node ids.

    Parameters
    ----------
    number_nodes : int
        Total number of nodes in the tree (``clf.tree_.node_count``).
    children_left_list, children_right_list : sequence of int
        Child id arrays indexed by node id. A node is a leaf when its left
        and right child ids are equal (scikit-learn stores -1 for both).

    Yields
    ------
    list of int
        One branch per leaf, e.g. ``[0, 2, 3, 5]``, in leaf-id order.
    """
    # A node is a leaf iff both child pointers coincide.
    is_leaf = [cl == cr for cl, cr in zip(children_left_list, children_right_list)]
    # Partial paths from the root, each ending at a not-yet-expanded node.
    paths = []
    for node in range(number_nodes):
        if is_leaf[node]:
            # Find the partial path that ends at this leaf and emit it.
            for idx, path in enumerate(paths):
                if path[-1] == node:
                    yield paths.pop(idx)
                    break
        else:
            # Extend the partial path ending at this internal node: replace
            # it with the left continuation and append the right one.
            for idx, path in enumerate(paths):
                if path[-1] == node:
                    paths[idx] = path + [children_left_list[node]]
                    paths.append(path + [children_right_list[node]])
            if node == 0:
                # Seed the two paths leaving the root.
                # Bug fix: the original read the module-level globals
                # children_left/children_right here instead of the
                # parameters, so it failed for any other arguments.
                paths.append([node, children_left_list[node]])
                paths.append([node, children_right_list[node]])
要调用 retrieve_branches,只需传入 n_nodes、children_left 和 children_right;函数内部会自行维护并更新存储分支路径的列表。最终结果显示如下。
# Materialize the generator: each element is one root-to-leaf node-id path.
all_branches = list(retrieve_branches(n_nodes, children_left, children_right))
all_branches  # REPL-style display of the collected branches
>>>
[[0, 1],
[0, 2, 3, 5],
[0, 2, 3, 6, 7],
[0, 2, 3, 6, 8],
[0, 2, 4, 9],
[0, 2, 4, 10]]
3. 每个分支的路径、节点值和基尼系数
可以从clf.tree_的feature和threshold的值,以及叶子节点处的杂质clf.tree_.impurity和clf.tree_.value的值中获取规则。
# Print each branch with its leaf impurity, class counts, and decision rules.
for index, branch in enumerate(all_branches):
    leaf_index = branch[-1]
    print(f'Branch: {index}, Path: {branch}')
    print(f'Gini {impurity[leaf_index]} at leaf node {leaf_index}')
    print(f'Value: {value[leaf_index]}')
    # Bug fix: build one rule per internal node only — the original also
    # iterated over the leaf itself, printing the bogus rule
    # "if X[:, -2] <= -2.0". Also emit '>' when the path follows the
    # right child, since '<=' is only the left-child condition.
    rules = []
    for parent, child in zip(branch[:-1], branch[1:]):
        op = '<=' if children_left[parent] == child else '>'
        rules.append(f'if X[:, {feature[parent]}] {op} {threshold[parent]}')
    print(f"Decision Rules: {rules}")
    print(f"---------------------------------------------------------------------------------------\n")
>>>
Branch: 0, Path: [0, 1]
Gin 0.0 at leaf node 1
Value: [[37. 0. 0.]]
Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, -2] <= -2.0']
---------------------------------------------------------------------------------------
Branch: 1, Path: [0, 2, 3, 5]
Gin 0.0 at leaf node 5
Value: [[ 0. 32. 0.]]
Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 3] <= 1.6500000357627869', 'if X[:, -2] <= -2.0']
---------------------------------------------------------------------------------------
Branch: 2, Path: [0, 2, 3, 6, 7]
Gin 0.0 at leaf node 7
Value: [[0. 0. 3.]]
Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 3] <= 1.6500000357627869', 'if X[:, 1] <= 3.100000023841858', 'if X[:, -2] <= -2.0']
---------------------------------------------------------------------------------------
Branch: 3, Path: [0, 2, 3, 6, 8]
Gin 0.0 at leaf node 8
Value: [[0. 1. 0.]]
Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 3] <= 1.6500000357627869', 'if X[:, 1] <= 3.100000023841858', 'if X[:, -2] <= -2.0']
---------------------------------------------------------------------------------------
Branch: 4, Path: [0, 2, 4, 9]
Gin 0.375 at leaf node 9
Value: [[0. 1. 3.]]
Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 2] <= 5.049999952316284', 'if X[:, -2] <= -2.0']
---------------------------------------------------------------------------------------
Branch: 5, Path: [0, 2, 4, 10]
Gin 0.0 at leaf node 10
Value: [[ 0. 0. 35.]]
Decision Rules: ['if X[:, 3] <= 0.800000011920929', 'if X[:, 2] <= 4.950000047683716', 'if X[:, 2] <= 5.049999952316284', 'if X[:, -2] <= -2.0']
---------------------------------------------------------------------------------------