【问题标题】:Decision Trees with SKlearn and Visualization带有 SKlearn 和可视化的决策树
【发布时间】:2017-05-01 04:25:42
【问题描述】:

研究 Kaggle Titanic 数据集。我试图更好地理解决策树,我已经使用线性回归很好,但从未使用过决策树。我正在尝试在 python 中为我的树创建一个可视化。但是有些东西不起作用。在下面检查我的代码。

import pandas as pd
from sklearn import tree
from sklearn.datasets import load_iris
import numpy as np


train_file='.......\RUN.csv'
train=pd.read_csv(train_file)

#impute number values and missing values
train["Sex"][train["Sex"] == "male"] = 0
train["Sex"][train["Sex"] == "female"] = 1
train["Embarked"] = train["Embarked"].fillna("S")
train["Embarked"][train["Embarked"] == "S"]= 0
train["Embarked"][train["Embarked"] == "C"]= 1
train["Embarked"][train["Embarked"] == "Q"]= 2
train["Age"] = train["Age"].fillna(train["Age"].median())
train["Pclass"] = train["Pclass"].fillna(train["Pclass"].median())
train["Fare"] = train["Fare"].fillna(train["Fare"].median())

target = train["Survived"].values
features_one = train[["Pclass", "Sex", "Age", "Fare","SibSp","Parch","Embarked"]].values


# Fit your first decision tree: my_tree_one
my_tree_one = tree.DecisionTreeClassifier(max_depth = 10, min_samples_split = 5, random_state = 1)

iris=load_iris()

my_tree_one = my_tree_one.fit(features_one, target)

tree.export_graphviz(my_tree_one, out_file='tree.dot')

我如何真正看到决策树?试图将其形象化。

帮助表示赞赏!

【问题讨论】:

    标签: python tree visualization decision-tree kaggle


    【解决方案1】:

    我使用条形图进行了可视化。第一个图表示类的分布。第一个标题代表第一个拆分标准。满足此标准的所有数据都会产生左侧的底层子图。如果没有,正确的情节就是结果。因此,所有标题都表示下一次拆分的拆分标准。

    百分比是来自初始分布的值。因此,通过查看百分比,我们可以很容易地了解初始数据量在经过几次拆分后还剩下多少。

    注意,如果你设置 max_depth 高,这将需要很多子图(max_depth,2^depth)

    Tree visualization using bar plots

    代码:

    def give_nodes(nodes,amount_of_branches,left,right):
        amount_of_branches*=2
        nodes_splits=[]
        for node in nodes:
            nodes_splits.append(left[node])
            nodes_splits.append(right[node])
        return (nodes_splits,amount_of_branches)
    
    def plot_tree(tree, feature_names):
        from matplotlib import gridspec 
        import matplotlib.pyplot as plt
        from matplotlib import rc
        import pylab
    
        color = plt.cm.coolwarm(np.linspace(1,0,len(feature_names)))
    
        plt.rc('text', usetex=True)
        plt.rc('font', family='sans-serif')
        plt.rc('font', size=14)
    
        params = {'legend.fontsize': 20,
                 'axes.labelsize': 20,
                 'axes.titlesize':25,
                 'xtick.labelsize':20,
                 'ytick.labelsize':20}
        plt.rcParams.update(params)
    
        max_depth=tree.max_depth
        left      = tree.tree_.children_left
        right     = tree.tree_.children_right
        threshold = tree.tree_.threshold
        features  = [feature_names[i] for i in tree.tree_.feature]
        value = tree.tree_.value
    
        fig = plt.figure(figsize=(3*2**max_depth,2*2**max_depth))
        gs = gridspec.GridSpec(max_depth, 2**max_depth)
        plt.subplots_adjust(hspace = 0.6, wspace=0.8)
    
        # All data
        amount_of_branches=1
        nodes=[0]
        normalize=np.sum(value[0][0])
    
        for i,node in enumerate(nodes):
            ax=fig.add_subplot(gs[0,(2**max_depth*i)/amount_of_branches:(2**max_depth*(i+1))/amount_of_branches])
            ax.set_title( features[node]+"$<= "+str(threshold[node])+"$")
            if( i==0): ax.set_ylabel(r'$\%$')
            ind=np.arange(1,len(value[node][0])+1,1)
            width=0.2
            bars= (np.array(value[node][0])/normalize)*100
            plt.bar(ind-width/2, bars, width,color=color,alpha=1,linewidth=0)
            plt.xticks(ind, [int(i) for i in ind-1])
            pylab.ticklabel_format(axis='y',style='sci',scilimits=(0,2))
    
        # Splits
        for j in range(1,max_depth):
            nodes,amount_of_branches=give_nodes(nodes,amount_of_branches,left,right)
            for i,node in enumerate(nodes):
                ax=fig.add_subplot(gs[j,(2**max_depth*i)/amount_of_branches:(2**max_depth*(i+1))/amount_of_branches])
                ax.set_title( features[node]+"$<= "+str(threshold[node])+"$")
                if( i==0): ax.set_ylabel(r'$\%$')
                ind=np.arange(1,len(value[node][0])+1,1)
                width=0.2
                bars= (np.array(value[node][0])/normalize)*100
                plt.bar(ind-width/2, bars, width,color=color,alpha=1,linewidth=0)
                plt.xticks(ind, [int(i) for i in ind-1])
                pylab.ticklabel_format(axis='y',style='sci',scilimits=(0,2))
    
    
        plt.tight_layout()
        return fig
    

    例子:

    X=[]
    Y=[]
    amount_of_labels=5
    feature_names=[ '$x_1$','$x_2$','$x_3$','$x_4$','$x_5$']
    for i in range(200):
        X.append([np.random.normal(),np.random.randint(0,100),np.random.uniform(200,500) ])
        Y.append(np.random.randint(0,amount_of_labels))
    
    clf = tree.DecisionTreeClassifier(criterion='entropy',max_depth=4)
    clf = clf.fit(X,Y )
    fig=plot_tree(clf, feature_names)
    

    【讨论】:

      【解决方案2】:

      来自维基百科:

      DOT 语言定义了一个图形,但不提供渲染图形的工具。有几个程序可用于渲染、查看和操作 DOT 语言中的图形:

      Graphviz - 一系列用于操作和呈现图形的库和实用程序

      Canviz - 一个用于渲染点文件的 JavaScript 库。

      Viz.js - 一个简单的 Graphviz JavaScript 客户端

      Grappa - Graphviz 到 Java 的部分移植。[4][5]

      Beluging - 基于 Python 和 Google Cloud 的 DOT 和 Beluga 扩展查看器。 [1]

      Tulip 可以导入点文件进行分析

      OmniGraffle 可以导入 DOT 的子集,生成可编辑的文档。 (但是,结果无法导出回 DOT。)

      ZGRViewer,一个 GraphViz/DOT 查看器链接

      VizierFX,一个 Flex 图形渲染库链接

      Gephi - 适用于各种网络和复杂系统、动态和层次图的交互式可视化和探索平台

      因此,这些程序中的任何一个都能够可视化您的树。

      【讨论】:

      • 我已经在使用 graphviz,但我无法让它显示为图像。它只是将其写入 .dot 文件。我尝试将 ti 更改为 pdf,但似乎无法正常工作。
      • 我相信这应该只是写.dot文件。然后,您必须使用列出的应用程序之一来查看 .dot 文件。我个人喜欢 Gephi。
      【解决方案3】:

      您检查了吗:http://scikit-learn.org/stable/modules/tree.html 提到如何将树绘制为 PNG 图像:

       from IPython.display import Image 
       import pydotplus
       dot_data = tree.export_graphviz(my_tree_one, out_file='tree.dot')  
       graph = pydotplus.graph_from_dot_data(dot_data)  `
       Image(graph.create_png())
      

      【讨论】:

      • >>> 导入操作系统 >>> os.unlink('iris.dot')
      • 它说要这样做^。但是,这只会删除文件。有任何想法吗?我也没有pydotplus。我尝试使用 pip 下载它,但没有成功。
      • 我认为问题出在 Graphiz 上,你应该下载它:graphviz.org/Download..phpstackoverflow.com/questions/18438997/…。首先安装graphiz,然后安装pydot。或者使用linux。我稍后再谈。
      猜你喜欢
      • 2020-01-09
      • 2012-07-03
      • 2010-10-22
      • 2020-08-17
      • 2017-08-10
      • 2021-12-25
      • 2023-04-08
      • 2014-08-28
      相关资源
      最近更新 更多