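【Posted】: 2023-03-19 16:05:02
【Question】:
I am looking for a way to convert a decision tree trained with scikit-learn (sklearn) into a decision table.
I would like to know how to parse the decision tree structure to find the decision made at each step.
Then I would like to know how to build the table from those decisions.
Do you know of a method, or have any ideas?
【Question comments】:
Tags: machine-learning datatable scikit-learn decision-tree converters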
Building on the other answer here, the following walks the tree in the same way but produces a pandas DataFrame as output.
import pandas as pd
from sklearn.tree import _tree

def tree_to_df(reg_tree, feature_names):
    tree_ = reg_tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    def recurse(node, row, ret):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            # Add rule to row and search left branch
            row[-1].append(name + " <= " + str(threshold))
            recurse(tree_.children_left[node], row, ret)
            # Add rule to row and search right branch
            row[-1].append(name + " > " + str(threshold))
            recurse(tree_.children_right[node], row, ret)
        else:
            # Add output rules and start a new row
            label = tree_.value[node]
            ret.append("return " + str(label[0][0]))
            row.append([])

    # Initialize
    rules = [[]]
    vals = []
    # Call recursive function with initial values
    recurse(0, rules, vals)
    # Convert to table and output (drop the trailing empty row)
    df = pd.DataFrame(rules).dropna(how='all')
    df['Return'] = pd.Series(vals)
    return df
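One thing to watch with the function above: a new row is started at every leaf, so rows after the first only hold the rules added since the previous leaf, not the whole path from the root. If you want every row to carry its complete set of conditions, a variant along these lines might work. This is only a sketch: `tree_to_rule_table` is a name I made up, it assumes a fitted classifier (it uses `classes_`), and the iris model at the bottom is just there to exercise it.

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, _tree


def tree_to_rule_table(model, feature_names):
    """Return one row per leaf holding every condition on the root-to-leaf path."""
    tree_ = model.tree_
    rows = []

    def recurse(node, path):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_names[tree_.feature[node]]
            threshold = round(float(tree_.threshold[node]), 4)
            # Each branch gets its own copy of the path, so siblings do not interfere.
            recurse(tree_.children_left[node], path + [f"{name} <= {threshold}"])
            recurse(tree_.children_right[node], path + [f"{name} > {threshold}"])
        else:
            # Leaf: keep the full path plus the class with the largest value.
            predicted = model.classes_[tree_.value[node][0].argmax()]
            rows.append(path + [predicted])

    recurse(0, [])
    depth = max(len(r) for r in rows) - 1  # longest path, excluding the prediction
    padded = [r[:-1] + [None] * (depth - len(r) + 1) + r[-1:] for r in rows]
    columns = [f"rule_{i + 1}" for i in range(depth)] + ["prediction"]
    return pd.DataFrame(padded, columns=columns)


# Example data and model, assumed here purely for illustration.
iris = load_iris()
clf = DecisionTreeClassifier(random_state=0, max_depth=2).fit(iris.data, iris.target)
table = tree_to_rule_table(clf, iris.feature_names)
print(table)
```

Shorter paths are padded with `None`, so the `prediction` column always stays last.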
【Comments】:
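Here is sample code that converts a decision tree into "Python" code. You can easily adapt it to build a table.
All you need to do is create a global variable holding a table with one row per leaf and one column per feature (or feature category), and fill it in recursively.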
from sklearn.tree import _tree

# NOTE: cast_value_to_dico is not defined in the original post. This is an
# assumed implementation: map each class name to the leaf's value for that
# class and return the mapping together with the majority class.
def cast_value_to_dico(value, classes_names):
    dico = dict(zip(classes_names, value[0]))
    label = max(dico, key=dico.get)
    return dico, label

def tree_to_code(tree, feature_names, classes_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print("def tree(" + ", ".join(feature_names) + "):")

    def recurse(node, depth):
        indent = "    " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print(indent + "if " + name + " <= " + str(threshold) + ":")
            recurse(tree_.children_left[node], depth + 1)
            print(indent + "else:  # if " + name + " > " + str(threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            impurity = tree_.impurity[node]
            dico, label = cast_value_to_dico(tree_.value[node], classes_names)
            print(indent + "# impurity=" + str(impurity) + " count_max=" + str(dico[label]))
            print(indent + "return " + str(label))

    recurse(0, 1)
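To make the "leaves by features" table idea a bit more concrete, here is one possible way to fill such a table recursively. Instead of a global variable it collects rows in a list, and each cell holds the interval that the feature is restricted to on the way to that leaf. Treat it as a sketch under assumptions: `tree_to_interval_table` is a hypothetical helper, it expects a fitted classifier, and the iris model is only example data.

```python
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, _tree


def tree_to_interval_table(model, feature_names):
    """Build a leaves x features table of (lower, upper] bounds per leaf."""
    tree_ = model.tree_
    rows = []

    def recurse(node, lower, upper):
        feat = tree_.feature[node]
        if feat != _tree.TREE_UNDEFINED:
            thr = float(tree_.threshold[node])
            # Left child means feature <= thr, so tighten the upper bound.
            left_upper = upper.copy()
            left_upper[feat] = min(upper[feat], thr)
            recurse(tree_.children_left[node], lower, left_upper)
            # Right child means feature > thr, so tighten the lower bound.
            right_lower = lower.copy()
            right_lower[feat] = max(lower[feat], thr)
            recurse(tree_.children_right[node], right_lower, upper)
        else:
            # Leaf: one table row, one interval per feature, plus the prediction.
            row = {
                name: f"({lo:.2f}, {hi:.2f}]"
                for name, lo, hi in zip(feature_names, lower, upper)
            }
            row["prediction"] = model.classes_[tree_.value[node][0].argmax()]
            rows.append(row)

    n = len(feature_names)
    recurse(0, np.full(n, -np.inf), np.full(n, np.inf))
    return pd.DataFrame(rows)


# Example data and model, assumed here purely for illustration.
iris = load_iris()
clf = DecisionTreeClassifier(random_state=0, max_depth=2).fit(iris.data, iris.target)
intervals = tree_to_interval_table(clf, iris.feature_names)
print(intervals)
```

Features that never appear on a leaf's path keep the unbounded interval `(-inf, inf]`, which reads as "don't care" in decision-table terms.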
【Comments】:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_text

iris = load_iris()
X = iris['data']
y = iris['target']
decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
decision_tree = decision_tree.fit(X, y)
r = export_text(decision_tree, feature_names=iris['feature_names'])
print(r)
listt = [r]
print(listt)
#########OUTPUT###########################
|--- petal width (cm) <= 0.80
|   |--- class: 0
|--- petal width (cm) >  0.80
|   |--- petal width (cm) <= 1.75
|   |   |--- class: 1
|   |--- petal width (cm) >  1.75
|   |   |--- class: 2
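If you go the `export_text` route, the report is just indented text, so it can be folded into table rows with a small parser. The sketch below is one way that might look. It assumes the default indentation (four characters per nesting level) and a classifier report whose leaves read `class: ...`; `report_to_rows` is a made-up helper, and the report string is hard-coded so the snippet runs without training a model.

```python
# Report text for the iris tree (random_state=0, max_depth=2), pasted in directly.
REPORT = """\
|--- petal width (cm) <= 0.80
|   |--- class: 0
|--- petal width (cm) >  0.80
|   |--- petal width (cm) <= 1.75
|   |   |--- class: 1
|   |--- petal width (cm) >  1.75
|   |   |--- class: 2
"""


def report_to_rows(report):
    """Turn an export_text report into one (conditions, label) pair per leaf."""
    rows, path = [], []
    for line in report.splitlines():
        marker = line.find("|--- ")
        if marker < 0:
            continue  # skip blank or unexpected lines
        depth = marker // 4  # each nesting level adds 4 characters
        text = " ".join(line[marker + 5:].split())  # normalise inner spacing
        del path[depth:]  # forget conditions from deeper or sibling branches
        if text.startswith("class:"):
            rows.append((list(path), text.split(":", 1)[1].strip()))
        else:
            path.append(text)
    return rows


rows = report_to_rows(REPORT)
for conditions, label in rows:
    print(" AND ".join(conditions), "=>", label)
```

Walking `tree_` directly is likely more robust than parsing text, but this can be handy when only the printed report is available.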
【Comments】: