【Question Title】: Why use a random forest to make sure my decision tree model doesn't overfit?
【Posted】: 2021-03-11 11:37:42
【Question】:

My code:

# Create a Decision Tree classifier object
clf = DecisionTreeClassifier()

# Train the Decision Tree classifier
clf = clf.fit(X_train, y_train)

# Using random forest to make sure my model doesn't overfit

from sklearn.ensemble import RandomForestClassifier

clf = RandomForestClassifier(n_estimators=20)  # n_estimators can be changed as needed
clf = clf.fit(ft,pima['brand'])

I would like a good explanation of how the random forest classifier is being used in the code above. What is the reason for using a random forest classifier here?

【Comments】:

  • "What is the reason for using this random forest classifier here" — unless we know the data, we can't say why.
  • Are you asking why a random forest is used instead of a single decision tree? That really isn't on topic for Stack Overflow, which is for specific programming problems. It's more of a theoretical machine-learning question and probably better suited to Cross Validated, although there are likely many answers on this topic there already.

标签: python machine-learning random-forest


【Solution 1】:

Whoa! What exactly is your question, and what is the end goal here? Basically, the random forest algorithm is built from decision trees. A single decision tree is very sensitive to changes in the data and easily overfits the noise in it. A random forest with only one tree will overfit as well, because it behaves the same as a single decision tree.
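To illustrate the point above, here is a minimal sketch (my addition, not part of the original answer) on a synthetic dataset with deliberately noisy labels: a fully grown single tree and a one-tree forest both memorize the training set, while a many-tree forest generalizes better. The dataset and hyperparameters are illustrative assumptions.

```python
# Sketch: a single tree and a 1-tree forest overfit label noise alike;
# a 100-tree forest reduces the train/test gap.
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier

# flip_y=0.2 randomly flips 20% of labels to inject noise
X, y = make_classification(n_samples=500, n_features=20, n_informative=5,
                           flip_y=0.2, random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=0)

single = DecisionTreeClassifier(random_state=0).fit(X_tr, y_tr)
# bootstrap=False and max_features=None make the 1-tree forest
# effectively identical to a plain decision tree
one_tree = RandomForestClassifier(n_estimators=1, bootstrap=False,
                                  max_features=None, random_state=0).fit(X_tr, y_tr)
many = RandomForestClassifier(n_estimators=100, random_state=0).fit(X_tr, y_tr)

for name, m in [("single tree", single), ("1-tree forest", one_tree),
                ("100-tree forest", many)]:
    print(name, "train:", m.score(X_tr, y_tr), "test:", m.score(X_te, y_te))
```

Both fully grown trees reach perfect training accuracy (they memorize the flipped labels), while the 100-tree forest typically closes part of the gap on the test set.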

As we add trees to the random forest, the tendency to overfit should decrease (thanks to bagging and random feature selection). However, the generalization error will not go to zero. As more trees are added, the variance of the generalization error approaches zero, but the bias does not!

Run the example below:

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error

# Import CSV mtcars
data = pd.read_csv('https://gist.githubusercontent.com/ZeccaLehn/4e06d2575eb9589dbe8c365d61cb056c/raw/64f1660f38ef523b2a1a13be77b002b98665cdfe/mtcars.csv')
# Edit element of column header
data.rename(columns={'Unnamed: 0':'brand'}, inplace=True)

X1 = data.iloc[:, 1:11]  # mpg .. gear (exclude 'carb', which is the target below)
Y1 = data.iloc[:, -1]    # 'carb'

# Let's plot a decision tree to find the feature importances
from sklearn.tree import DecisionTreeClassifier
tree= DecisionTreeClassifier(criterion='entropy', random_state=1)
tree.fit(X1, Y1)

imp = pd.DataFrame(index=X1.columns, data=tree.feature_importances_, columns=['Imp'])
imp = imp.sort_values(by='Imp', ascending=False)

sns.barplot(x=imp.index.tolist(), y=imp.values.ravel(), palette='coolwarm')

X=data[['cyl','disp','hp','drat','wt','qsec','vs','am','gear','carb']]
y=data['mpg']

# Split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.3, random_state = 0)


rf = RandomForestRegressor(n_estimators=50)
rf.fit(X_train, y_train)
y_train_predicted = rf.predict(X_train)
y_test_predicted_full_trees = rf.predict(X_test)
mse_train = mean_squared_error(y_train, y_train_predicted)
mse_test = mean_squared_error(y_test, y_test_predicted_full_trees)
print("RF with full trees, Train MSE: {} Test MSE: {}".format(mse_train, mse_test))


rf = RandomForestRegressor(n_estimators=50, min_samples_leaf=25)
rf.fit(X_train, y_train)
y_train_predicted = rf.predict(X_train)
y_test_predicted_pruned_trees = rf.predict(X_test)
mse_train = mean_squared_error(y_train, y_train_predicted)
mse_test = mean_squared_error(y_test, y_test_predicted_pruned_trees)
print("RF with pruned trees, Train MSE: {} Test MSE: {}".format(mse_train, mse_test))


rf = RandomForestRegressor(n_estimators=1, warm_start=True)  # warm_start keeps already-fitted trees when refitting
for iter in range(50):
    rf.fit(X_train, y_train)
    y_train_predicted = rf.predict(X_train)
    y_test_predicted = rf.predict(X_test)
    mse_train = mean_squared_error(y_train, y_train_predicted)
    mse_test = mean_squared_error(y_test, y_test_predicted)
    print("Iteration: {} Train mse: {} Test mse: {}".format(iter, mse_train, mse_test))
    rf.n_estimators += 1


from sklearn.tree import DecisionTreeRegressor
from sklearn import tree

# To keep the size of the tree small, I set max_depth = 3.
# Fit the regressor, set max_depth = 3
regr = DecisionTreeRegressor(max_depth=3, random_state=1234)
model = regr.fit(X, y)

# 1
text_representation = tree.export_text(regr)
print(text_representation)



# 2
fig = plt.figure(figsize=(25,20))
_ = tree.plot_tree(regr, 
                   feature_names=X.columns,  
                   filled=True)
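Besides comparing train and test MSE as above, a random forest also offers a built-in overfitting check that needs no held-out set: the out-of-bag (OOB) score. The sketch below is my addition (not part of the original answer) and uses a synthetic regression dataset as an assumption; each tree is scored on the bootstrap samples it never saw.

```python
# Sketch: oob_score=True scores each tree on its out-of-bag samples,
# giving a generalization estimate (R^2 for regressors) for free.
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

X, y = make_regression(n_samples=300, n_features=10, noise=10.0, random_state=0)

rf = RandomForestRegressor(n_estimators=200, oob_score=True, random_state=0)
rf.fit(X, y)

print("OOB R^2:  ", rf.oob_score_)   # out-of-bag estimate of generalization
print("Train R^2:", rf.score(X, y))  # typically higher -> some overfitting remains
```

A large gap between the training score and `oob_score_` is a sign the forest is still overfitting; pruning options such as `min_samples_leaf` (used above) shrink that gap.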

【Discussion】:
