【Question title】: Random forest visualization in Python
【Posted】: 2020-08-08 23:23:51
【Question】:

TypeError                                 Traceback (most recent call last)
<ipython-input-25-e7781de34abc> in <module>
      3                feature_names = fn,
      4                class_names=cn,
----> 5                filled = False);
      6 fig.savefig('rf_individualtree.png')

~/opt/anaconda3/lib/python3.7/site-packages/sklearn/tree/_export.py in plot_tree(decision_tree, max_depth, feature_names, class_names, label, filled, impurity, node_ids, proportion, rotate, rounded, precision, ax, fontsize)
    174         proportion=proportion, rotate=rotate, rounded=rounded,
    175         precision=precision, fontsize=fontsize)
--> 176     return exporter.export(decision_tree, ax=ax)
    177 
    178 

~/opt/anaconda3/lib/python3.7/site-packages/sklearn/tree/_export.py in export(self, decision_tree, ax)
    565         ax.set_axis_off()
    566         my_tree = self._make_tree(0, decision_tree.tree_,
--> 567                                   decision_tree.criterion)
    568         draw_tree = buchheim(my_tree)
    569 

~/opt/anaconda3/lib/python3.7/site-packages/sklearn/tree/_export.py in _make_tree(self, node_id, et, criterion, depth)
    546         # traverses _tree.Tree recursively, builds intermediate
    547         # "_reingold_tilford.Tree" object
--> 548         name = self.node_to_str(et, node_id, criterion=criterion)
    549         if (et.children_left[node_id] != _tree.TREE_LEAF
    550                 and (self.max_depth is None or depth <= self.max_depth)):

~/opt/anaconda3/lib/python3.7/site-packages/sklearn/tree/_export.py in node_to_str(self, tree, node_id, criterion)
    340                                           np.argmax(value),
    341                                           characters[2])
--> 342             node_string += class_name
    343 
    344         # Clean up any trailing newlines

TypeError: can only concatenate str (not "numpy.int64") to str
import warnings

import matplotlib.pyplot as plt
import pandas as pd
from sklearn import tree
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
%matplotlib inline
warnings.filterwarnings("ignore")

df = pd.read_csv('heart.csv')
df.head()

# Features: every column except 'target'; labels: the 0/1 'target' column
x = df.loc[:, df.columns != 'target']
y = df.loc[:, 'target'].values

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=0)

# Trees do not need feature scaling; with it, the split thresholds shown in the plot are standardized values
sc = StandardScaler()
x_train = sc.fit_transform(x_train)
x_test = sc.transform(x_test)

rf = RandomForestClassifier(n_estimators=100, random_state=0)
rf.fit(x_train, y_train)

# Feature names in the same order as the columns the model was trained on
fn = list(x.columns)
cn = df.target  # one numpy.int64 per row; this is the value plot_tree fails on

fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(4, 4), dpi=800)
tree.plot_tree(rf.estimators_[0],
               feature_names=fn,
               class_names=cn,
               filled=False);
fig.savefig('rf_individualtree.png')

I am following this structure to visualize one tree of my random forest:

https://i.stack.imgur.com/MkH71.png

When I run the tree.plot_tree(...) call, I get the error TypeError: can only concatenate str (not "numpy.int64") to str.
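A minimal sketch of the failure outside scikit-learn, assuming (as the last traceback frame shows with node_string += class_name) that plot_tree builds each node label by appending the class name to a str. Appending a numpy.int64 raises a TypeError, while appending str(label) works. The exact error message may differ between NumPy versions, so the sketch only relies on the exception type.

```python
import numpy as np

label = np.int64(1)          # what indexing a pandas Series like df.target returns
node_string = "class = "
raised = False

try:
    node_string += label     # str + numpy.int64, as in the traceback's node_to_str frame
except TypeError as exc:
    raised = True
    print("TypeError:", exc)

node_string += str(label)    # converting to str first works
print(node_string)           # prints: class = 1
```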

I am using the dataset from Kaggle: https://www.kaggle.com/ronitf/heart-disease-uci

I would be grateful if you could help me.

【Comments】:

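  • This is probably because at some point (in tree.plot_tree) the code expects a string but you are giving it an integer. Do me a favor: just before plot_tree, add print(type(fn), type(cn), type(fn[0]), type(cn[0])) to your code and check whether any of them is not a string or a list. If so, you should write fn=[str(x) for x in fn] and cn=[str(x) for x in cn].
  • Also, if you post your code as text instead of an image, people can test it much faster than if they have to copy it out of a picture. The same goes for the error: if you paste the full message, it is easier to find the line that is causing the problem.
  • Hi, thanks for your reply. I have uploaded the whole code.
  • I just uploaded it in the HTML console. Sorry, I'm new here, so I'm not very familiar with the features, and it seems I can't use many of them yet. I hope you can help me.
  • Sorry if I sounded too harsh. The solution is on its way.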
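The type check suggested in the first comment can be sketched on a tiny hypothetical DataFrame standing in for heart.csv (the column values below are made up, not the real data). It shows that fn holds plain strings while cn is a pandas Series whose elements are NumPy integer scalars, which is the value the traceback complains about.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for a few rows of heart.csv (not the real data)
df = pd.DataFrame({
    "age": [63, 37, 41, 56],
    "chol": [233, 250, 204, 236],
    "target": [1, 1, 0, 0],
})

fn = list(df.columns[df.columns != "target"])
cn = df.target

# Which of these is not a str, or not a list of str?
print(type(fn), type(cn), type(fn[0]), type(cn[0]))
# fn[0] is a str; cn is a Series and cn[0] is a NumPy integer scalar, not a str
```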

Tags: python-3.x visualization random-forest


【Answer 1】:

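The class_names parameter of plot_tree expects a list of strings, one name per target class in ascending numerical order, but in your code cn is df.target: a pandas Series holding one integer (numpy.int64, to be precise) per row. Build the list from the classes themselves, convert each one to a string, and the problem is solved.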
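Why one string per class rather than one per row: the node_to_str frame in the traceback shows the label being chosen via np.argmax(value), i.e. by class index, and scikit-learn documents class_names as the names of the target classes in ascending numerical order. A small sketch with hypothetical labels whose first rows are all 1 shows that converting df.target row by row removes the TypeError but can still print the wrong class name.

```python
import numpy as np
import pandas as pd

# Hypothetical labels: the target=1 rows happen to come first
target = pd.Series([1, 1, 1, 0, 0])
classes = np.unique(target)            # sorted class labels, like a classifier's classes_: [0, 1]

per_row = [str(x) for x in target]     # one string per ROW:   ['1', '1', '1', '0', '0']
per_class = [str(c) for c in classes]  # one string per CLASS: ['0', '1']

class_index = 0                        # a node whose majority class is 0
print(per_row[class_index])            # prints 1 (wrong label)
print(per_class[class_index])          # prints 0 (correct label)
```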

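#some code before
fn = list(x.columns)

# class_names needs one string per class, in ascending order (rf.classes_ is [0, 1]),
# not one value per row of df.target; converting each class label to str also
# removes the numpy.int64 that caused the TypeError
cn = [str(c) for c in rf.classes_]

fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(4, 4), dpi=800)
tree.plot_tree(rf.estimators_[0],
               feature_names=fn,
               class_names=cn,
               filled=False);
fig.savefig('rf_individualtree.png')
#some code after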
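An end-to-end sketch of the same fix that runs without heart.csv. Assumptions: synthetic data from make_classification stands in for the Kaggle dataset, the names feature_0 to feature_4 are hypothetical, and the Agg backend is selected so the script also runs outside Jupyter. class_names is built from rf.classes_, so there is exactly one string per class in the order the trees use.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend (assumption: running as a plain script)
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Synthetic stand-in for heart.csv: 5 numeric features, binary 0/1 target
X, y = make_classification(n_samples=200, n_features=5, random_state=0)
fn = [f"feature_{i}" for i in range(X.shape[1])]  # hypothetical feature names

rf = RandomForestClassifier(n_estimators=10, max_depth=3, random_state=0)
rf.fit(X, y)

# One string per class, in the order of rf.classes_ (here [0, 1])
cn = [str(c) for c in rf.classes_]

fig, ax = plt.subplots(figsize=(8, 6), dpi=100)
annotations = tree.plot_tree(rf.estimators_[0],
                             feature_names=fn,
                             class_names=cn,
                             filled=True,
                             ax=ax)
fig.savefig("rf_individualtree.png")
plt.close(fig)
```

For a real dataset, meaningful names can replace the digit strings, as long as the list stays in the order of rf.classes_.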

