【问题标题】:Plotting a decision tree manually with pyplot使用 pyplot 手动绘制决策树
【发布时间】:2021-10-01 16:45:48
【问题描述】:

我是matplotlib 的新手,我正在尝试绘制我从头开始构建的决策树(不是使用sklearn)所以它基本上是一个Node 对象与leftright 和其他递归构建的标识变量。

这是我的程序:

def plot_tree(node, x_axis=0, y_axis=10, space=5):
    if node.label is not None:
        ax.text(x_axis, y_axis, node.label[0],
                bbox=dict(boxstyle='round', facecolor='green', edgecolor='g'), ha='center', va='center')

    else:
        ax.text(x_axis, y_axis, f'{node.value:.2f}\nidx:{node.feature_idx}',
                     bbox=dict(boxstyle='round', facecolor='red', edgecolor='r'), ha='center', va='center')

        # x2, y2, w2, h2 = t2.get_tightbbox(fig.canvas.get_renderer()).bounds
        # plt.annotate(' ', xy=(x2 + w2, y2 + h2), xytext=(x_axis, y_axis), xycoords='figure points',
        #              arrowprops=dict(arrowstyle="<|-,head_length=1,head_width=0.5", lw=2, color='b'))

        plot_tree(node.left, x_axis + space, y_axis + space)
        plot_tree(node.right, x_axis + space, y_axis - space)


if __name__ == '__main__':
    node = root.load_tree()
    fig, ax = plt.subplots(1, 1)
    ax.axis('off')
    ax.set_aspect('equal')
    ax.autoscale_view()
    ax.set_xlim(0, 30)
    ax.set_ylim(-10, 30)
    plt.tick_params(axis='both', labelsize=0, length=0)
    plot_tree(node)

我的结果:

我知道 y 轴因为 y_axis + spacey_axis - space 而发生碰撞,但我真的不知道如何让它的间距保持对称而不是这样。 正如您看到的那样,箭头被注释掉了,因为它们本身就是一团糟,这个库非常丰富,弄清楚它有点让人不知所措。

编辑:这是树的打印表示:

 split is at feature:  27  and value  0.14235  and depth is:  1
     split is at feature:  20  and value  17.615000000000002  and depth is:  2
         label is:  B and depth is:  3
         split is at feature:  8  and value  0.15165  and depth is:  3
             label is:  B and depth is:  4
             label is:  M and depth is:  4
     split is at feature:  13  and value  13.93  and depth is:  2
         label is:  B and depth is:  3
         label is:  M and depth is:  3

【问题讨论】:

    标签: python python-3.x matplotlib decision-tree


    【解决方案1】:

    您最好使用 Graphviz,因为它会为您处理间距。下载Graphviz 和它的Python bindings,然后你可以很容易地像这样渲染图形:

    dot = graphviz.Digraph(comment="A graph", format="svg")
    dot.node('A', 'King Arthur')
    dot.node('B', 'Sir Bedevere the Wise')
    dot.node('C', 'Sir Lancelot the Brave')
    dot.edge('A', 'B')
    dot.edge('A', 'C')
    dot.render('digraph.gv', view=True)  
    

    【讨论】:

    • 是的,我知道Graphviz 是一个选项,但我的任务是使用pyplot 来完成它
    猜你喜欢
    • 2014-08-30
    • 2016-11-09
    • 1970-01-01
    • 2013-05-30
    • 2017-02-21
    • 2019-06-14
    • 2017-01-30
    • 2018-02-09
    • 1970-01-01
    相关资源
    最近更新 更多