上一章提到的线性回归应该是使用最广泛的一种方法,但是这个方法时常会出现一些问题:比如我们需要对线性回归模型的随机干扰项施加一些假设,但现实中这些假设常常不能满足,此外当数据拥有众多特征且特征之间关系比较复杂时,用线性回归难以构建一个全局模型,最重要的,现实中满足线性关系的特征只占一小部分,大部分情况下特征之间满足的是非线性关系。
这时一种可行的方法就是将数据集切分成多份容易建模的数据,在切分的子集上建立回归模型,这里要用的结构就是回归树。


之前在决策树一章我们使用了ID3ID3算法构建决策树,这种算法每次选取当前最佳特征分割数据,并按照该特征所有可能取值来分割,这种方法切割过于迅速,容易产生过拟合,而且不能直接处理连续型数据,必须先把连续型数据分组为分类型数据。
我们对这一切分方法做一些改进,比如使用二元切分法,即每次把数据集分成两份,如果特征值大于给定值走左子树,否则走右子树。

具体地,在训练数据集所在的输入空间中,递归地将每个区域划分为两个子区域并决定每个子区域上的输出值:

  1. 选择最优切分变量jj和切分点ss,使得minj,s[minc1xiR1(j,s)(yic1)2+minc2xiR2(j,s)(yic2)2]\min_{j,s} \left[\min_{c_1} \sum_{x_i \in R_1(j,s)} (y_i-c_1)^2+\min_{c_2}\sum_{x_i\in R_2(j,s)}(y_i-c_2)^2 \right] 其中c1,c2c_1,c_2是输出在切分单元上的取值
  2. 用选定的对(j,s)(j,s)划分区域并确定相应的输出值,容易知道在每个区域上的输出值是该区域上所有输入实例对应输出的均值c^m=1NmxiRm(j,s)yi,xRm,m=1,2\hat{c}_m = \frac1{N_m}\sum_{x_i\in R_m(j,s)}y_i,x\in R_m,m=1,2
  3. 继续对两个子区域调用步骤1,2,直到满足停止条件
  4. 将输入空间划分为MM个区域R1,,RMR_1,\cdots,R_M,生成决策树:f(x)=m=1Mc^mI(xRm)f(x) = \sum_{m=1}^M\hat{c}_mI(x\in R_m)

在生成了回归树之后,我们还需要考虑对其进行剪枝,常用的CARTCART剪枝算法由两步构成:首先从生成算法产生的决策树T0T_0底端开始剪枝,直到T0T_0的根节点,得到子树序列{T0,T1,,Tn}\{T_0,T_1,\cdots, T_n\};然后通过交叉验证法在独立的验证数据集上对子树序列进行测试,选择最优子树。

  1. 剪枝。对整体树T0T_0内的任一内部节点tt,若已经剪枝,则以tt为单节点的损失函数为Cα(t)=C(t)+αC_{\alpha}(t) = C(t) + \alpha,而剪枝前,咦tt为根节点的子树TtT_t损失函数为Cα(Tt)=C(Tt)+αTtC_{\alpha}(T_t) = C(T_t) + \alpha|T_t|。我们在这里假设必然会发生剪枝,所以要求Cα(Tt)Cα(t)C_{\alpha}(T_t) \ge C_{\alpha}(t),因此只需取α1=C(t)C(Tt)Tt1\alpha_1 = \frac{C(t) - C(T_t)}{|T_t| - 1},此时就比如会发生剪枝。但是事实上并不是对于任何一个子节点都需要进行剪枝,因此有另一个假设:剪枝发生后,当前决策树是最优子树。当剪枝已经发生时,我们知道对于每一个子节点tt会有不同的α\alpha,此时记为Cα(t)=C(t)+α(t)C_{\alpha}(t) = C(t) + \alpha(t),我们需要找到tt,使得mintC(t)+α(t)\min_t C(t) + \alpha(t),这里转化为最小化mint(C(t)C(Tt)Tt1)\min_t (\frac{C(t) - C(T_t)}{|T_t| - 1})找到了最小的α\alpha即找到了最小的tt,从而就完成了剪枝。再考虑到Breiman等人已经证明过,可以利用递归的方法对数进行剪枝,使得最优子树序列{T0,T1,,Tn}\{T_0,T_1,\cdots,T_n\}对应区间α[αi,αi+1]\alpha \in [\alpha_i, \alpha_{i+1}],并且子树是嵌套的。
  2. 交叉验证选择最优子树。在第一步中我们得到了最优子树序列{T0,T1,,Tn}\{T_0,T_1,\cdots,T_n\}。利用验证集计算子树序列的平方误差,选择平方误差最小的决策树为最优决策树。

那么接下来我们考虑用代码实现上述回归树构建和剪枝过程。
先做一些准备工作,包括载入数据,以及实现切分函数,不过切分函数里的最优参数需要在后面确定:
机器学习实战(8):回归树
接下来考虑构建切分函数,切分函数需要找到切分数据集的最佳变量以及最佳切分点,同时还要能够生成相应的叶节点。
机器学习实战(8):回归树机器学习实战(8):回归树
有了这个切分函数之后,我们就可以构建回归树了:
机器学习实战(8):回归树
构建了回归树之后,为了能让回归树有更好的预测效果,我们需要对回归树进行剪枝。实际上之前的程序里已经有剪枝的思想了,比如我们设定了切分时的两个变量tolS,tolN,通过设定这两个参数我们可以提前终止构建回归树。但是这种误差常常需要人为指定,存在主观性,效果并不是很好。

接下来我们考虑对回归树进行简化的CART剪枝。
机器学习实战(8):回归树机器学习实战(8):回归树

相关文章: