【问题标题】:Visualizing a decision tree ( example from scikit-learn )可视化决策树(来自 scikit-learn 的示例)
【发布时间】:2012-05-21 03:40:22
【问题描述】:

我是使用 sciki-learn 的菜鸟,所以请多多包涵。

我正在查看示例: http://scikit-learn.org/stable/modules/tree.html#tree

>>> from sklearn.datasets import load_iris
>>> from sklearn import tree
>>> iris = load_iris()
>>> clf = tree.DecisionTreeClassifier()
>>> clf = clf.fit(iris.data, iris.target)
>>> from StringIO import StringIO
>>> out = StringIO()
>>> out = tree.export_graphviz(clf, out_file=out)

显然,graphiz 文件已经可以使用了。

但是如何使用 graphiz 文件绘制树呢? (该示例没有详细说明如何绘制树)。

非常欢迎示例代码和提示!

谢谢!


更新

我使用的是 ubuntu 12.04,Python 2.7.3

【问题讨论】:

  • 0.21 版的 Scikit-learn 具有 plot_tree 方法,它比导出到 graphviz 更容易使用。无论如何,还有一个非常不错的包 dtreeviz。下面是sklearn树的可视化方法对比:blog post link.

标签: machine-learning python-2.7 scipy scikit-learn


【解决方案1】:

您运行哪个操作系统?你有graphviz 安装吗?

在您的示例中,StringIO() 对象包含 graphviz 数据,这是检查数据的一种方法:

...
>>> print out.getvalue()

digraph Tree {
0 [label="X[2] <= 2.4500\nerror = 0.666667\nsamples = 150\nvalue = [ 50.  50.  50.]", shape="box"] ;
1 [label="error = 0.0000\nsamples = 50\nvalue = [ 50.   0.   0.]", shape="box"] ;
0 -> 1 ;
2 [label="X[3] <= 1.7500\nerror = 0.5\nsamples = 100\nvalue = [  0.  50.  50.]", shape="box"] ;
0 -> 2 ;
3 [label="X[2] <= 4.9500\nerror = 0.168038\nsamples = 54\nvalue = [  0.  49.   5.]", shape="box"] ;
2 -> 3 ;
4 [label="X[3] <= 1.6500\nerror = 0.0407986\nsamples = 48\nvalue = [  0.  47.   1.]", shape="box"] ;
3 -> 4 ;
5 [label="error = 0.0000\nsamples = 47\nvalue = [  0.  47.   0.]", shape="box"] ;
4 -> 5 ;
6 [label="error = 0.0000\nsamples = 1\nvalue = [ 0.  0.  1.]", shape="box"] ;
4 -> 6 ;
7 [label="X[3] <= 1.5500\nerror = 0.444444\nsamples = 6\nvalue = [ 0.  2.  4.]", shape="box"] ;
3 -> 7 ;
8 [label="error = 0.0000\nsamples = 3\nvalue = [ 0.  0.  3.]", shape="box"] ;
7 -> 8 ;
9 [label="X[0] <= 6.9500\nerror = 0.444444\nsamples = 3\nvalue = [ 0.  2.  1.]", shape="box"] ;
7 -> 9 ;
10 [label="error = 0.0000\nsamples = 2\nvalue = [ 0.  2.  0.]", shape="box"] ;
9 -> 10 ;
11 [label="error = 0.0000\nsamples = 1\nvalue = [ 0.  0.  1.]", shape="box"] ;
9 -> 11 ;
12 [label="X[2] <= 4.8500\nerror = 0.0425331\nsamples = 46\nvalue = [  0.   1.  45.]", shape="box"] ;
2 -> 12 ;
13 [label="X[0] <= 5.9500\nerror = 0.444444\nsamples = 3\nvalue = [ 0.  1.  2.]", shape="box"] ;
12 -> 13 ;
14 [label="error = 0.0000\nsamples = 1\nvalue = [ 0.  1.  0.]", shape="box"] ;
13 -> 14 ;
15 [label="error = 0.0000\nsamples = 2\nvalue = [ 0.  0.  2.]", shape="box"] ;
13 -> 15 ;
16 [label="error = 0.0000\nsamples = 43\nvalue = [  0.   0.  43.]", shape="box"] ;
12 -> 16 ;
}

您可以将其写为.dot file 并生成图像输出,如您链接的源代码所示:

$ dot -Tpng tree.dot -o tree.png(PNG格式输出)

【讨论】:

  • 您好,谢谢!我使用的是 Ubuntu 12.04,Python 版本 2.7.3。我想知道我是否可以在 python 脚本中而不是在命令行中做到这一点?
  • 当然,只需获取可用的Python bindings to graphviz 之一,您应该可以在 python shell 中进行操作
  • 有什么方法可以在python3中完成这个任务吗?
【解决方案2】:

你很亲密!做吧:

graph_from_dot_data(out.getvalue()).write_pdf("somefile.pdf")

【讨论】:

  • 这只有在#classes 足够小以至于文本中的 nvalue 数组不会跨行断开时才有效...在这种情况下,我不得不手动搜索/替换 \n 为 '' (当然,保留合法的)......有点痛苦。同样适用于单热编码标签...它们会立即抛出错误。
猜你喜欢
  • 2015-03-05
  • 2020-09-08
  • 2017-07-21
  • 2017-02-23
  • 2020-04-05
  • 2017-03-26
  • 1970-01-01
  • 2019-12-26
相关资源
最近更新 更多