【Posted】: 2022-01-15 20:45:46
【Question】:
I want to find the actual depth and the number of leaves of the trees in my XGBoost regression model.
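【Question discussion】:
- (How to know the number of trees created in XGBoost) ---> this gives the number of trees in the forest, but I want to know the depth of each tree.
Tags: python scikit-learn xgboost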
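【Answer】:
To get this information you need to retrieve the booster object. I assume you are using the scikit-learn interface (XGBRegressor works the same way). For example, take a model with 3 estimators (trees) and a maximum depth of 7: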
import xgboost as xgb
from sklearn.datasets import make_classification

X, y = make_classification(random_state=99)
clf = xgb.XGBModel(objective='binary:logistic', n_estimators=3, max_depth=7)
clf.fit(X, y)
Next we pull out the booster and convert the tree information into a data frame (only the rows for tree 0 are shown):
booster = clf.get_booster()
tree_df = booster.trees_to_dataframe()
tree_df[tree_df['Tree'] == 0]
    Tree  Node    ID  Feature      Split   Yes    No  Missing        Gain  Cover
 0     0     0   0-0      f11  -0.233068   0-1   0-2      0-1   48.161629  25.00
 1     0     1   0-1       f1  -1.081945   0-3   0-4      0-3    0.054384   9.25
 2     0     2   0-2      f14   0.480458   0-5   0-6      0-5    8.410727  15.75
 3     0     3   0-3     Leaf        NaN   NaN   NaN      NaN   -0.150000   1.00
 4     0     4   0-4     Leaf        NaN   NaN   NaN      NaN   -0.535135   8.25
 5     0     5   0-5      f18   0.261421   0-7   0-8      0-7    5.638095   6.50
 6     0     6   0-6       f9  -1.585489   0-9  0-10      0-9    0.727795   9.25
 7     0     7   0-7      f18  -0.640538  0-11  0-12     0-11    4.342857   4.00
 8     0     8   0-8       f0   0.072811  0-13  0-14     0-13    1.028571   2.50
 9     0     9   0-9     Leaf        NaN   NaN   NaN      NaN    0.163636   1.75
10     0    10  0-10     Leaf        NaN   NaN   NaN      NaN    0.529412   7.50
11     0    11  0-11     Leaf        NaN   NaN   NaN      NaN   -0.120000   1.50
12     0    12  0-12     Leaf        NaN   NaN   NaN      NaN    0.428571   2.50
13     0    13  0-13     Leaf        NaN   NaN   NaN      NaN   -0.000000   1.00
14     0    14  0-14     Leaf        NaN   NaN   NaN      NaN   -0.360000   1.50
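The depth of a decision tree is the largest number of splits on any path from the root to a leaf. In the table, the longest paths (e.g. 0-0 → 0-2 → 0-5 → 0-7 → 0-11) contain 4 splits, so this tree has a depth of 4 even though max_depth is 7. You can visualize the first tree to check (this requires the graphviz and matplotlib packages):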
xgb.plotting.plot_tree(booster, num_trees=0)
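The same depth, plus the leaf count the question asks about, can be read off the ID, Yes and No columns. Below is a minimal standard-library sketch: the rows are copied by hand from the tree 0 output above, and `tree_stats` is a hypothetical helper, not an XGBoost function. It walks the tree breadth-first from the root and records how many splits lie above each leaf.

```python
from collections import deque

# (ID, Yes, No) for tree 0, copied from the trees_to_dataframe output above.
# Leaf rows have no children, represented here as None.
ROWS = [
    ("0-0", "0-1", "0-2"),
    ("0-1", "0-3", "0-4"),
    ("0-2", "0-5", "0-6"),
    ("0-3", None, None),
    ("0-4", None, None),
    ("0-5", "0-7", "0-8"),
    ("0-6", "0-9", "0-10"),
    ("0-7", "0-11", "0-12"),
    ("0-8", "0-13", "0-14"),
    ("0-9", None, None),
    ("0-10", None, None),
    ("0-11", None, None),
    ("0-12", None, None),
    ("0-13", None, None),
    ("0-14", None, None),
]


def tree_stats(rows, root="0-0"):
    """Return (depth, n_leaves) of one tree given (ID, Yes, No) rows.

    Depth is the largest number of splits on any root-to-leaf path.
    """
    children = {node_id: (yes, no) for node_id, yes, no in rows}
    depth = 0
    n_leaves = 0
    queue = deque([(root, 0)])  # (node ID, number of splits above it)
    while queue:
        node_id, level = queue.popleft()
        yes, no = children[node_id]
        if yes is None:  # leaf: record its distance from the root
            n_leaves += 1
            depth = max(depth, level)
        else:  # split node: both children are one split deeper
            queue.append((yes, level + 1))
            queue.append((no, level + 1))
    return depth, n_leaves


print(tree_stats(ROWS))  # (4, 8)
```

With a fitted model you would build the rows for each tree from `tree_df` instead, treating the NaN in the Yes/No columns of leaf rows as "no children". If only the leaf count is needed, `tree_df[tree_df['Feature'] == 'Leaf'].groupby('Tree').size()` should give it per tree directly.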
There may be a better solution, but as a quick approach I used the solution from this post: walk the JSON dump of each tree and compute its depth:
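import json

def item_generator(json_input, lookup_key):
    # Recursively yield every value stored under lookup_key in nested dicts/lists
    if isinstance(json_input, dict):
        for k, v in json_input.items():
            if k == lookup_key:
                yield v
            else:
                yield from item_generator(v, lookup_key)
    elif isinstance(json_input, list):
        for item in json_input:
            yield from item_generator(item, lookup_key)

def tree_depth(json_text):
    # "depth" is recorded on split nodes (root = 0), so the deepest split + 1
    # is the number of splits on the longest root-to-leaf path.
    # default=-1 makes a tree that is a single leaf report depth 0.
    json_input = json.loads(json_text)
    return max(item_generator(json_input, 'depth'), default=-1) + 1

[tree_depth(x) for x in booster.get_dump(dump_format="json")]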
[4, 4, 5]
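The helper above relies on the `depth` key, which in the JSON dump appears to be written only on split nodes (hence the +1). A minimal alternative sketch, assuming the dump layout in which split nodes carry a `children` list and leaf nodes carry a `leaf` value, recurses over the children and returns both depth and leaf count. The dump string is a small hand-written example in that layout, not real model output, and `depth_and_leaves` is a hypothetical helper name.

```python
import json

# Hand-written dump in the XGBoost JSON layout (illustrative, not real output):
# the root splits, its "no" child splits again, all other nodes are leaves.
DUMP = """
{"nodeid": 0, "depth": 0, "split": "f11", "split_condition": -0.23,
 "yes": 1, "no": 2, "missing": 1, "children": [
   {"nodeid": 1, "leaf": -0.5},
   {"nodeid": 2, "depth": 1, "split": "f14", "split_condition": 0.48,
    "yes": 3, "no": 4, "missing": 3, "children": [
      {"nodeid": 3, "leaf": 0.1},
      {"nodeid": 4, "leaf": 0.5}
   ]}
 ]}
"""


def depth_and_leaves(node):
    """Return (depth, n_leaves) for a parsed JSON tree node."""
    if "children" not in node:  # leaf node: no further splits
        return 0, 1
    results = [depth_and_leaves(child) for child in node["children"]]
    depth = 1 + max(d for d, _ in results)  # this split plus the deepest subtree
    n_leaves = sum(n for _, n in results)
    return depth, n_leaves


print(depth_and_leaves(json.loads(DUMP)))  # (2, 3)
```

With a fitted model this would be applied as `[depth_and_leaves(json.loads(t)) for t in booster.get_dump(dump_format="json")]`, giving one (depth, leaves) pair per tree.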
【Discussion】: