机器学习中决策树
一、简介
1.定义
决策树是一种属性结构,其中:内部节点:代表对某个特征的判断(特征),分支:代表判断结果,叶子节点:代表最终分类结果(标签)。
2.决策树构建三步骤
(1)特征选择:筛选对分类 / 回归贡献最大的特征
(2)决策树生成:基于选定特征递归分裂数据集,生成初步树结构;
(3)剪枝:缓解过拟合(初步树可能过度贴合训练数据,泛化能力差,需裁剪冗余分支)。
注:决策树的核心优势是可解释性极强(类似 “if-else” 逻辑,易理解),无需特征标准化(如年龄和收入无需统一量级),但缺点是易过拟合,需依赖剪枝优化。
二、ID3决策树(离散特征专用)
1.核心指标:信息熵与信息增益
(1)信息熵(Entropy)
定义:信息论中衡量数据不确定性的指标,熵越大,数据混乱度(不确定性)越高;熵越小,数据纯度越高。
计算公式:对于数据集D,若包含k类样本,各类别占比为p1,p2...pk,则信息熵为:(对数底为 2,单位为 “比特”)
案例验证:
数据α(ABCDEFGH):8 类样本,每类占比(1/8),(H(α) = -8×(1/8)log_2(1/8) = 3);
数据β(AAAABBCD):4 类样本(A:1/2,B:1/4,C:1/8,D:1/8),(H(β) = -(1/2log_2 1/2 + 1/4log_2 1/4 + 2×1/8log_2 1/8) ≈ 1.75);
结论:(H(α) > H(β)),数据 α 更混乱。
注:信息熵由香农(Shannon)在 1948 年《通信的数学理论》中提出,最初用于衡量通信中的 “信息不确定性”,后来被引入机器学习作为特征选择的核心指标。
(2)信息增益
定义:特征a对数据集D的信息增益,等于 “数据集原熵H(D)” 减去 “特征a条件下的条件熵H(D|a)”,代表通过特征a分裂后,数据不确定性减少的程度。
公式:g(D,a)=H(D)-H(D|a)
(条件熵H(D|a):按特征α划分后的各子集熵的加权平均,权重为子集占总数据集的比例)
案例计算(6 个样本:3A、3B,特征a分 α(4 样本:3A1B)、β(2 样本:2B)):
原熵(H(D) = -(3/6log_2 3/6 + 3/6log_2 3/6) = 1);
条件熵(H(D|a) = (4/6)×[-(3/4log_2 3/4 + 1/4log_2 1/4)] + (2/6)×[-(2/2log_2 2/2)] ≈ 0.54);
信息增益g(D,a)=1-0.54=0.46。
2.ID3决策树构建流程
(1)计算数据集中所有特征的信息增益;
(2)选择信息增益最大的特征作为当前节点的分裂特征;
(3)按该特征的取值将数据集拆分为子集;
(4)对每个子集重复步骤 1-3,直到所有子集纯度达到阈值(如子集全为同一类别)或无特征可分裂。
3.典型案例:论坛客户流失分析
数据:15 条样本(5 个流失正样本,10 个未流失负样本),特征为 “性别”“活跃度”;
步骤:计算原熵→分别计算 “性别”“活跃度” 的信息增益→比较得出 “活跃度信息增益更大”,对流失的影响更显著。
注:ID3 的致命缺陷:①仅支持离散特征(无法处理年龄、收入等连续特征);②偏向选择取值数量多的特征(如 “用户 ID” 这类唯一值特征,信息增益极高,但无实际意义),后续 C4.5 决策树专门解决此问题。
三、C4.5决策树(ID3改进版)
1.ID3的痛点与C4.5的改进方向
ID3痛点:偏向选择取值多的特征(如特征 b 有 6 个取值,特征 a 仅 2 个,ID3 易选 b,但 b 可能无实际预测价值);
C4.5改进:用信息增益率替代信息增益,引入 “惩罚系数” 修正多取值特征的优势。
2.核心指标:信息增益率
定义:信息增益率 = 信息增益 / 特征熵(特征熵:以特征a为随机变量的熵,衡量特征自身的不确定性);
惩罚逻辑:若特征a取值多,其特征熵大,信息增益率会被 “稀释”,从而避免偏向多取值特征;
案例验证(特征 a:2 取值,特征 b:6 取值):
特征b的信息增益可能高,但特征熵更大(6 个唯一值,熵≈2.58),信息增益率低;
特征a的信息增益率更高,最终被选为分裂特征。
3.C4.5的额外优势
支持连续特征:通过 “离散化” 处理(如将年龄分为 [0-18,19-30,31+]);
处理缺失值:用 “样本权重” 弥补(如某样本缺失 “年龄”,则按其他特征的分布分配权重);
剪枝优化:自带后剪枝逻辑,降低过拟合风险。
注:C4.5 的局限:①计算复杂度高(需多次计算熵和增益率);②无法处理超大数据集(需将数据全部载入内存,不支持分布式),工业界常用 CART 或随机森林(基于 CART)替代。
四、CART决策树(分类+回归双用途)
CART(Classification and Regression Tree)是最常用的决策树模型,支持分类任务(用基尼指数)和回归任务(用平方损失),且无论特征类型,均生成二叉树(每个节点仅分 2 个分支)
1.CART分类树(预测离散类别)
(1)核心指标:基尼指数
定义:从数据集D中随机抽取 2 个样本,其类别标记不一致的概率,取值范围 [0,0.5],值越小,数据纯度越高;
公式:
基尼指数计算逻辑:选择使 “分裂后总基尼指数最小” 的特征和分裂点(如离散特征按 “是否为某取值” 分,连续特征按 “是否大于某阈值” 分)。
(2)典型案例:是否拖欠贷款预测
数据:10 条样本,特征为 “是否有房”“婚姻状况”“年收入”,目标为 “是否拖欠贷款”;
计算关键:
[是否有房]:有房样本(3 个,全为 “不拖欠”,基尼 = 0),无房样本(7 个,4 不拖欠 3 拖欠,基尼≈0.49),总基尼指数 =(3/10)×0 +(7/10)×0.49≈0.343;
[婚姻状况]:按 “是否已婚” 分裂,已婚样本(4 个,全不拖欠,基尼 = 0),未婚样本(6 个,3 不拖欠 3 拖欠,基尼 = 0.5),总基尼指数 =(4/10)×0 +(6/10)×0.5=0.3;
[年收入]:按 97.5 为阈值分裂,总基尼指数 = 0.3;
结论:选择 “婚姻状况(是否已婚)” 或 “年收入(97.5)” 作为分裂特征(基尼指数最小)。
2.CART回归树(预测连续值)
(1)与分类树的核心区别
维度 | CART 分类树 | CART 回归树 |
---|---|---|
输出类型 | 离散类别(如 “拖欠 / 不拖欠”) | 连续值(如 “房价”“收入”) |
损失函数 | 基尼指数 | 平方损失(Loss=(f(x)-y)^2) |
叶子节点输出 | 子集中多数类别的标签 | 子集中所有样本的均值 |
(2)CART回归树构建流程
1.对特征x的取值排序,取相邻值的均值作为候选分裂点(如(x=[1,2,...,10]),候选点为 1.5,2.5,...,9.5);
2.对每个候选点,将数据分为 “≤分裂点” 和 “>分裂点” 两个子集,计算两个子集的平方损失之和;
3.选择平方损失之和最小的分裂点作为当前特征的最优分裂点;
4.对每个子集重复步骤 1-3,直到子集满足停止条件(如子集样本数≤阈值)。
(3)典型案例:特征x与目标y的回归
数据:x=[1-10],(y=[5.56,5.7,5.91,6.4,6.8,7.05,8.9,8.7,9,9.05];
步骤:
1.候选分裂点为 1.5-9.5,计算各点平方损失,发现s=6.5时损失最小(m(s)=1.93);
2.按s=6.5分裂为左子集(x≤6,6 个样本)和右子集(x>6,4 个样本);
3.对左子集继续分裂,发现s=3.5时损失最小,最终生成二叉树。
注:CART 的核心优势:①支持分类 + 回归双任务;②二叉树结构简洁,计算效率高;③可处理离散 / 连续特征,工业界应用最广(如随机森林、GBDT 等集成模型的基础均为 CART)。
五、泰塔尼克案例
1.案例背景
数据:泰坦尼克号乘客数据(特征:Pclass(舱位等级)、Age(年龄)、Sex(性别)等,目标:Survived(是否生存));
核心逻辑:通过决策树预测乘客是否能生存(历史事实:妇女、儿童、高舱位乘客生存率更高)。
2.实现步骤(Python+sklearn)
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, roc_auc_score, roc_curve
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt
# 设置中文字体,解决中文显示问题
plt.rcParams["font.family"] = ["SimHei"]
# 解决负号显示问题
plt.rcParams['axes.unicode_minus'] = False#读取数据
data = pd.read_csv('../data/train.csv')
# data.info()
#数据预处理
# df = data.isnull().sum()
# print(df)
data['Sex'] = data['Sex'].map({'female': 0, 'male': 1})
data['Age'] = data['Age'].fillna(data['Age'].mean()) # 数值列填充
# data.fillna({'Age':data['Age'].mean()}, inplace=True)
data['Embarked'] = data['Embarked'].fillna(data['Embarked'].mode()[0])
data['Embarked'] = data['Embarked'].map({'S':0, 'C':1, 'Q':2})
data.drop('Cabin', axis=1, inplace=True)
data.info()X = data[['Pclass','Sex','Age', 'SibSp', 'Parch','Fare', 'Embarked']]
y = data['Survived']
# X = pd.get_dummies(X['Embarked'], prefix='Embarked')
# print(X.shape)
#特征处理
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=5999,stratify=y)
# transfer = StandardScaler()
# X_train = transfer.fit_transform(X_train)
# X_test = transfer.transform(X_test)
#创建模型
es = DecisionTreeClassifier()
#训练模型
es.fit(X_train,y_train)
#模型预测
y_predict = es.predict(X_test)
#模型评估
y_proba = es.predict_proba(X_test)[:,1]
print(f"准确值:{accuracy_score(y_test, y_predict)}")
print(f"召回率:{recall_score(y_test, y_predict)}")
print(f"精准率:{precision_score(y_test, y_predict)}")
print(f"f1分数:{f1_score(y_test, y_predict)}")
print(f"AUC指标:{roc_auc_score(y_test, y_proba)}")plt.figure(figsize=(8,4))
plot_tree(es,filled=True,max_depth=30)
plt.show()
3.关键结论
重要特征:Sex(性别)>Pclass(舱位)>Age(年龄),符合 “妇女优先、高舱位优先” 的历史事实;
调参建议:通过max_depth
(树深)、min_samples_leaf
(叶子节点最小样本数)限制树的复杂度,降低过拟合。
注:实际项目中,决策树的调参核心是控制复杂度:①max_depth
不宜过大(如超过 10 易过拟合);②min_samples_split
(内部节点分裂最小样本数)设为 5-10,避免小样本分裂;③ccp_alpha
(成本复杂度剪枝参数)可自动剪枝,sklearn 推荐使用。
六、决策树剪枝(防止过拟合的核心手段)
1.剪枝的必要性
决策树若不剪枝,会过度贴合训练数据(如每个叶子节点仅 1 个样本),导致泛化能力差(测试集准确率低),剪枝通过 “删除冗余分支”,保留核心逻辑,提升泛化能力。
2.两种剪枝方法对比
剪枝类型 | 核心逻辑 | 优点 | 缺点 |
---|---|---|---|
预剪枝 | 树生成过程中,对每个节点先判断:若分裂后验证集准确率无提升,则停止分裂,标记为叶子节点 | 训练 / 测试效率高,避免冗余分支生成 | 可能欠拟合(过早停止分裂,错过后续有效分支) |
后剪枝 | 先生成完整决策树,再自底向上遍历非叶节点:若将子树替换为叶子节点后准确率提升,则剪枝 | 泛化能力强,欠拟合风险低 | 训练效率低(需生成完整树再剪枝) |
3.案例:西瓜数据集剪枝
预剪枝:生成 “脐部 =?” 节点后,若 “色泽 =?” 分裂后验证集准确率从 71.4% 降至 57.1%,则禁止分裂,保留 “脐部” 节点;
后剪枝:先生成含 6 个内部节点的完整树,再判断 “色泽 =?”“纹理 =?” 等节点,若剪枝后准确率提升(如从 57.1%→71.4%),则剪枝。
注:工业界常用后剪枝 + 预剪枝结合:①先用预剪枝快速生成基础树;②再用后剪枝优化局部分支;③sklearn 中DecisionTreeClassifier
的ccp_alpha
参数实现 “成本复杂度剪枝”(后剪枝的一种),通过最小化 “训练损失 +α× 树复杂度” 选择最优树。
七、三大分类决策树对比
决策树类型 | 提出时间 | 分支方式 | 支持特征类型 | 核心优势 | 核心局限 |
---|---|---|---|---|---|
ID3 | 1975 | 信息增益 | 仅离散 | 计算简单,易理解 | 不支持连续 / 缺失值,偏多取值特征 |
C4.5 | 1993 | 信息增益率 | 离散 + 连续 | 解决 ID3 偏倚,支持缺失值 | 计算复杂,不支持大数据集 |
CART | 1984 | 基尼指数(分类) | 离散 + 连续 | 支持分类 / 回归,二叉树高效 | 对异常值敏感 |
平方损失(回归) |
八、核心知识点梳理
1.决策树定义:内部节点是特征判断,叶子是分类结果,可明确特征重要性;
2.信息熵:熵越大,混乱度越高
3.信息熵增率:缓解多取值特征偏倚,C4.5 核心
4.基尼指数:CART 分类树核心,值越小纯度越高
5.剪枝方法:预剪枝(生成时剪)、后剪枝(生成后剪),均为防过拟合