CART算法实现
1.python3
|
本文将构建两种树,一种回归树(regression tree),其每个节点包含单个值;第二种是模型树(model tree),其每个叶节点包含一个线性方程。 createTree()伪代码: 找到最佳的待切分特征: 如果该节点不能再分,将该节点存为叶节点 执行二元切分 在右子树调用createTree()方法 在左子树调用createTree()方法
from numpy import * #dataSet是数据集合,feature是待切分的特征,value是该特征的某个值 对上述代码构造假数据进行测试: import regTrees [[ 1. 0. 0. 0.] [ 0. 1. 0. 0.] [ 0. 0. 1. 0.] [ 0. 0. 0. 1.]] ****** [[ 0.] [ 1.] [ 0.] [ 0.]] ****** [[ 0. 1. 0. 0.]] ****** [[ 1. 0. 0. 0.] [ 0. 0. 1. 0.] [ 0. 0. 0. 1.]] 将CART算法用于回归: 回归树假设叶节点是常数值,这种策略认为数据中的复杂关系可以用树结构来概括。 chooseBestSplit(): 功能:给定某个误差计算方法,该函数会找到数据集熵最佳的二元切分方式。此外该函数还要确定什么时候停止切分,一旦停止切分会生成一个叶节点。所以该函数需要完成两个方面: 1.用最佳方式切分数据集 2.生成相应的叶节点 伪代码: 对每个特征: 对每个特征值: 将数据集切分成两份 计算切分的误差 如果当前误差小于当前最小误差,那么将当前切分设定为最佳切分并更新最小误差返回最佳切分的特征和阈值 def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
代码测试: from numpy import * |
图5-1 基于CART算法构建回归树的简单数据集
from numpy import * |
图5-2 用于测试回归树的分段常数数据集
|
该树包含5个叶节点。上述过程已经完成回归树的构建,但是需要某种措施来检查构建的过程是否恰当。 from numpy import * |
图5-3 将图5-1的轴扩大100倍后的新数据集
|
图5-3看上去和图5-1分长相思但是y轴的数量级是5-1的100倍这里的新树有很多叶节点,而5-1只有2个。这是因为停止条件tolS对误差的数量级十分敏感。 下面讨论后剪枝,即利用测试集来对树进行剪枝,由于不需要用户指定参数,后剪枝是一个更理想化的方法。 后剪枝方法需要将数据集分类成测试集和训练集。首先指定参数,使得构建出的树足够大,足够复杂,便于剪枝。接下来从上而下找到叶节点,用测试集判断将这些叶节点合并是否能降低测试误差,如果是的话就合并。 Prune()伪代码如下: 基于已有的树切分测试数据: 如果存在任一字集是一棵树,则在该子集递归剪枝过程 计算将当前两个叶子节点和并后的误差 计算不合并的误差 如果合并会降低误差的话,就将叶节点合并
def isTree(obj): 代码测试: from numpy import * |
图5-4 用来测试某行书构建函数的分段线性数据
|
模型树的叶节点生成函数 #模型树 {'spInd': 0, 'spVal': 0.285477, 'left': matrix([[ 1.69855694e-03], [ 1.19647739e+01]]), 'right': matrix([[ 3.46877936], [ 1.18521743]])} |
2.R语言实现
|
#查看所有连续变量的相关性,所有分类变量的卡方值 > idx.num <- which(sapply(algae,is.numeric)) > idx.num mxPH mnO2 Cl NO3 NH4 oPO4 PO4 Chla a1 a2 a3 a4 a5 a6 a7 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 > correlation <- cor(algae$a1,algae[,idx.num],use = "pairwise.complete.obs") > correlation mxPH mnO2 Cl NO3 NH4 oPO4 [1,] -0.2651354 0.2873732 -0.3711709 -0.2412111 -0.132656 -0.4173576 PO4 Chla a1 a2 a3 a4 a5 [1,] -0.4864228 -0.2779866 1 -0.2937678 -0.1465666 -0.03795656 -0.2915492 a6 a7 [1,] -0.2734283 -0.2129063 > correlation <- abs(correlation) > correlation <- correlation[,order(correlation,decreasing = T)] > correlation a1 PO4 oPO4 Cl a2 a5 mnO2 1.00000000 0.48642276 0.41735761 0.37117086 0.29376781 0.29154923 0.28737317 Chla a6 mxPH NO3 a7 a3 NH4 0.27798661 0.27342831 0.26513541 0.24121109 0.21290633 0.14656656 0.13265601 a4 0.03795656 #查看分类变量的卡方值 > idx.factor <- which(sapply(algae,is.factor)) > idx.factor season size speed 1 2 3 > class(idx.factor) [1] "integer" > algae[,idx.factor] season size speed 1 winter small medium 2 spring small medium 3 autumn small medium 4 spring small medium 5 autumn small medium 6 winter small high 7 summer small high 8 autumn small high 9 winter small medium 10 winter small high ……
|