【发布时间】:2015-02-28 17:52:00
【问题描述】:
我想获取使用 sklearn.tree 进行预测的节点的所有信息。
例如:
from sklearn.datasets import load_iris
nfrom sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier()
iris = load_iris()
clf = clf.fit(iris.data, iris.target)
现在我们可以使用以下方法预测类:
clf.predict(iris.data[0, :])
如何获取进行预测的叶子节点以及叶子中存储的信息?
我知道上面示例中树的图形表示如下:
http://scikit-learn.org/stable/modules/tree.html#tree-classification
所以我知道输入iris.data[0, :](第一个左孩子)对应的节点有如下统计:
- 错误=0
- 样本=50
- 值 = [50 0 0]
是否可以在不打印树的情况下自动获取输出节点和(以上)信息?根据我目前的理解,关键是获取叶节点的 ID 进行预测,然后将相关统计信息包含在 clf.tree_.value[ID] 和clf.tree_.n_samples[ID].
谢谢
【问题讨论】:
标签: python scikit-learn regression