决策树算法学习笔记
一、引言
决策树是机器学习中一种常用的分类与回归算法,其核心思想是通过对数据特征的逐步划分,构建一棵类似树状的决策模型。本次学习主要围绕决策树的经典算法(ID3、C4.5、CART)、关键技术(连续值处理、剪枝策略)及代码实现展开。
二、核心决策树算法
(一)ID3 算法
核心指标:信息增益 信息增益是某个属性带来的熵减(纯度提升),计算公式基于信息熵的变化。信息增益越大,说明使用该属性划分数据获得的 “纯度提升” 越显著,因此 ID3 算法以信息增益作为划分属性的选择标准。
局限性 信息增益准则对可取值数目较多的属性存在偏好。例如,若数据集中存在 “编号” 这类唯一取值的属性,其信息增益可能最大,但显然不具备实际分类意义。
(二)C4.5 算法
为改进 ID3 的属性偏好问题,C4.5 算法引入信息增益率作为划分标准: 信息增益率 = 信息增益 ÷ 属性自身的熵 通过除以属性自身的熵,可对取值较多的属性进行 “惩罚”,缓解偏好问题。
(三)CART 算法
核心指标:基尼指数 基尼指数用于衡量数据集的纯度,计算公式为:\(Gini(D) = 1 - \sum_{k=1}^{n}p_k^2\),其中\(p_k\)是数据集D中第k类样本的占比。 基尼指数反映了从数据集D中随机抽取两个样本,其类别标记不一致的概率。\(p_k\)越大(即某类样本占比越高),基尼指数越小,数据集纯度越高。
三、连续值处理方法
现实数据中存在大量连续值特征(如收入、年龄等),决策树处理连续值的核心是离散化,具体步骤如下:
排序:对连续特征的所有取值进行排序;
找分界点:采用贪婪算法,在排序后的相邻取值中间选取可能的分界点进行二分。例如,若特征取值为 [60,70,75,85,90,95,100,120,125,220],则存在 9 个可能的分界点,可尝试分割为 “≤80” 与 “>80”“≤97.5” 与 “>97.5” 等组合,选择最优分界点。
四、剪枝策略(解决过拟合)
决策树若不加以限制,理论上可完全拟合训练数据,导致过拟合风险极高,因此需要剪枝优化。
(一)预剪枝
定义:边构建决策树边进行剪枝的操作,更具实用性。
实现方式:通过设置限制条件停止树的生长,例如:
限制树的最大深度;
限制叶子节点的最少样本数;
限制信息增益的最小值(若划分的信息增益低于阈值则停止)。
(二)后剪枝
定义:待决策树完全构建完成后再进行剪枝。
衡量标准:通过损失函数判断是否剪枝,损失函数公式为: 最终损失 = 子树自身的基尼系数值 + α× 叶子节点数量
α 越大:对叶子节点数量的惩罚越重,树越简单,过拟合风险越低,但可能牺牲模型精度;
α 越小:更注重模型精度,过拟合风险可能较高。
案例:某分支剪枝前验证集精度为 57.1%,剪枝后提升至 71.4%,说明剪枝有效降低了过拟合。
五、决策树代码实现(基于DecisionTreeClassifier
)
scikit-learn
中的DecisionTreeClassifier
是常用的决策树实现工具,核心参数如下:
参数 | 含义 |
---|---|
criterion | 划分标准,可选gini (基尼指数)或entropy (信息熵) |
splitter | 切分点选择方式,best (所有特征中找最优)或random (部分特征中找) |
max_features | 划分时考虑的最大特征数,可选None (所有)、log2 、sqrt 或具体数值 |
max_depth | 树的最大深度,深度越大越易过拟合,推荐设置为 5-20 之间 |
六、实践应用
课堂练习中提出使用决策树对泰坦尼克号幸存者进行预测,通过选择合适的特征(如年龄、性别、票价等),利用上述算法与参数设置构建模型,可实现对幸存者的分类预测。
总结
决策树算法通过直观的树状结构实现分类决策,ID3、C4.5、CART 分别基于信息增益、信息增益率、基尼指数进行特征划分;连续值通过离散化处理融入模型;预剪枝与后剪枝策略有效缓解过拟合问题。在实际应用中,需根据数据特点选择合适的算法与参数,以提升模型性能。