机器学习之决策树:从原理到实战(附泰坦尼克号预测任务)
大家好!今天我们来深入探讨机器学习中经典且实用的算法 —— 决策树。无论是数据分类还是回归任务,决策树都以其直观易懂、可解释性强的特点被广泛应用。本文将从决策树的核心算法(ID3、C4.5、CART)讲起,带你理解连续值处理、剪枝策略,最后通过代码实战完成泰坦尼克号幸存者预测任务,保证内容详实且易上手。
一、决策树核心算法:从信息增益到基尼指数
决策树的本质是通过 “划分属性” 构建一棵 “判断树”,每一个内部节点代表一次属性判断,每一个叶子节点代表最终的类别(分类任务)或输出值(回归任务)。不同算法的核心差异在于划分属性的选择标准,常见的有 ID3、C4.5、CART 三种。
1.1 ID3 算法:以 “信息增益” 为核心
ID3 是最早的决策树算法之一,它通过 “信息增益” 选择最优划分属性,核心逻辑是:信息增益越大,用该属性划分后数据集的 “纯度提升” 越明显。
关键概念
- 信息熵(Entropy):衡量数据集纯度的指标,熵越小,数据越纯。公式为:
Entropy(D)=−∑k=1npklog2pk
其中pk是数据集D中第k类样本的占比。 - 信息增益:某属性a对数据集D的信息增益,等于 “划分前的熵” 减去 “划分后各子集熵的加权和”,公式为:
Gain(D,a)=Entropy(D)−∑v∈Values(a)∣D∣∣Dv∣Entropy(Dv)
局限性
ID3 对可取值数目多的属性有天然偏好。例如 “样本编号” 这类属性,每个取值对应一个子集,划分后熵为 0,信息增益最大,但显然无实际意义(过拟合风险极高)。
1.2 C4.5 算法:用 “信息增益率” 修正偏差
为解决 ID3 的属性偏好问题,C4.5 引入 “信息增益率” 作为划分标准,通过 “自身熵” 对信息增益进行归一化。
关键概念
- 信息增益率:公式为:
Gain_ratio(D,a)=IV(a)Gain(D,a)
其中IV(a)=−∑v∈Values(a)∣D∣∣Dv∣log2∣D∣∣Dv∣ 是属性a的 “固有值”—— 属性取值越多,IV(a)越大,从而抑制对多取值属性的偏好。
示例:基于 “是否出去玩” 数据集
以 PPT 中的示例数据集为例(共 7 条样本,特征包括天气、温度、湿度、是否多云,标签为 “是否出去玩”):
ID | 天气 | 温度 | 湿度 | 是否多云 | 是否出去玩 |
---|---|---|---|---|---|
1 | 晴 | 高 | 高 | 是 | 是 |
2 | 阴 | 适中 | 高 | 是 | 否 |
3 | 雨 | 高 | 正常 | 否 | 否 |
4 | 雨 | 适中 | 正常 | 否 | 是 |
5 | 阴 | 高 | 高 | 是 | 是 |
6 | 晴 | 低 | 正常 | 否 | 是 |
7 | 阴 | 低 | 正常 | 是 | 是 |
若用 C4.5 计算 “天气” 属性的信息增益率:
- 先计算划分前的总熵Entropy(D);
- 计算 “天气” 属性各取值(晴、阴、雨)对应的子集熵,得到信息增益天气;
- 计算 “天气” 属性的固有值天气;
- 最终得到信息增益率天气,并与其他属性(如温度、湿度)比较,选择最大的作为划分属性。
1.3 CART 算法:用 “基尼指数” 简化计算
CART(分类与回归树)是目前应用最广的决策树算法,支持分类和回归任务。在分类任务中,它用 “基尼指数” 替代信息熵,计算效率更高(无需对数运算)。
关键概念
- 基尼指数(Gini (D)):反映从数据集D中随机抽取两个样本,类别标记不一致的概率。公式为:
Gini(D)=1−∑k=1npk2
其中pk是第k类样本的占比。
核心逻辑:pk越大(数据越纯),Gini (D) 越小。例如,若数据集全为同一类,Gini (D)=0;若两类样本各占 50%,Gini (D)=0.5(纯度最低)。
CART 的特点
- 只生成二叉树:无论属性有多少取值,每次划分都将数据分为两部分(如 “温度≤25℃” 和 “温度 > 25℃”);
- 适用场景广:分类任务用基尼指数 / 信息熵,回归任务用均方误差(MSE);
- 计算高效:避免信息熵的对数运算,适合大规模数据。
二、决策树的 “难点突破”:连续值处理
现实数据中,很多特征是连续值(如年龄、收入),而 ID3/C4.5/CART 原生只支持离散值。如何处理连续值?核心思路是将连续值 “离散化”,具体步骤如下(基于 PPT 示例):
步骤 1:排序连续值
以 PPT 中 “应税收入(Taxable Income)” 为例,10 条样本的收入数据为:60K、70K、75K、85K、90K、95K、100K、120K、125K、220K,先按从小到大排序。
步骤 2:确定候选分界点
连续值的 “二分法” 划分中,候选分界点为相邻两个值的中点。例如上述 10 个值有 9 个候选分界点:65K((60+70)/2)、72.5K((70+75)/2)、…、172.5K((125+220)/2)。
步骤 3:用 “贪婪算法” 选最优分界点
对每个候选分界点,计算划分后的 “信息增益”(ID3/C4.5)或 “基尼指数”(CART),选择最优的分界点。例如:
- 若用 “Taxable Income ≤80K” 划分,得到两个子集,计算其信息增益;
- 若用 “Taxable Income ≤97.5K” 划分,计算另一组信息增益;
- 最终选择信息增益最大(或基尼指数最小)的分界点作为该属性的划分标准。
本质上,这一步是将连续特征转化为 “是 / 否” 的离散特征,从而适配决策树的划分逻辑。
三、决策树的 “防坑指南”:剪枝策略
决策树有一个致命问题 ——过拟合风险极高。理论上,它可以无限分裂节点,直到每个叶子节点只包含一个样本(完全拟合训练数据,但对新数据泛化能力差)。因此,“剪枝” 是决策树不可或缺的步骤,核心是通过 “牺牲部分训练精度” 换取 “更好的泛化能力”,常见策略分为预剪枝和后剪枝。
3.1 预剪枝:“边建边剪”,高效实用
预剪枝是在决策树构建过程中同步进行剪枝,提前停止节点分裂,避免过拟合。它的优点是计算量小、效率高,是工业界更常用的方式。
常见预剪枝规则
- 限制树的最大深度:例如设置
max_depth=10
,树的深度超过 10 则停止分裂(深度越大,过拟合风险越高,推荐值 5-20); - 限制叶子节点最少样本数:例如设置
min_samples_leaf=5
,若某节点分裂后叶子节点样本数 < 5,则不分裂; - 限制叶子节点总数:例如设置
max_leaf_nodes=20
,叶子节点数超过 20 则停止; - 限制信息增益阈值:若某属性的信息增益(或增益率、基尼指数)低于阈值,则不分裂该节点。
预剪枝的特点
- 优点:操作简单、计算高效,能有效避免过拟合;
- 缺点:可能 “欠剪枝”—— 因提前停止分裂,导致模型未充分学习数据规律,精度偏低。
3.2 后剪枝:“先建后剪”,精度更高
后剪枝是先构建完整的决策树(不限制分裂),再从叶子节点向根节点反向剪枝,通过 “损失函数” 判断是否保留子树。它的优点是泛化能力更强,缺点是计算量较大。
核心:损失函数与 α 参数
后剪枝的核心是比较 “保留子树” 和 “剪枝为叶子节点” 的损失,选择损失更小的方案。损失函数公式(基于 CART 分类树):
Loss(T)=C(T)+α×∣T∣
- C(T):子树T的 “训练误差”,分类任务中常用 “所有叶子节点的基尼指数加权和”;
- ∣T∣:子树T的叶子节点数量;
- α:正则化参数,控制剪枝强度:
- α越大:正则化越强,更倾向于剪枝(减少叶子节点数),避免过拟合,但可能降低精度;
- α越小:正则化越弱,更倾向于保留复杂子树,精度可能更高,但过拟合风险大。
后剪枝示例(基于 PPT“好瓜 / 坏瓜” 案例)
以 PPT 中 “色泽 =?” 和 “纹理 =?” 分支的剪枝为例:
- 对 “色泽 =?” 分支:剪枝前验证集精度 57.1%,剪枝后 71.4%(损失更小),因此选择剪枝;
- 对 “纹理 =?” 分支:剪枝前验证集精度 42.9%,剪枝后 57.1%(损失更小),因此选择剪枝。
后剪枝的关键是用验证集评估精度,确保剪枝后的模型对新数据有更好的表现。
四、代码实战:用决策树预测泰坦尼克号幸存者
理论讲完,我们通过实战巩固 —— 用scikit-learn
的DecisionTreeClassifier
构建决策树,预测泰坦尼克号乘客是否幸存。
4.1 数据准备
首先需要获取泰坦尼克号数据集(若未下载,可从Kaggle 泰坦尼克号竞赛页面下载,核心文件为train.csv
)。数据集包含乘客的基本信息(如年龄、性别、船票等级)和是否幸存的标签(Survived
:1 = 幸存,0 = 遇难)。
4.2 核心代码实现
步骤 1:导入所需库
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.preprocessing import LabelEncoder # 处理分类特征
from sklearn.metrics import accuracy_score, classification_report # 评估模型
步骤 2:加载并预处理数据
决策树无法直接处理字符串特征(如性别、Embarked),需先进行编码;同时处理缺失值(如年龄缺失):
# 加载数据(确保文件路径正确,若报错FileNotFoundError,参考前文解决)
data = pd.read_csv("train.csv", index_col=0) # index_col=0用PassengerId作为索引# 1. 选择核心特征(剔除无关特征如Name、Ticket、Cabin)
features = ["Pclass", "Sex", "Age", "SibSp", "Parch", "Fare", "Embarked"]
X = data[features]
y = data["Survived"] # 标签:是否幸存# 2. 处理缺失值(年龄用均值填充,Embarked用众数填充)
X["Age"].fillna(X["Age"].mean(), inplace=True)
X["Embarked"].fillna(X["Embarked"].mode()[0], inplace=True)# 3. 编码分类特征(Sex:男=1,女=0;Embarked:C=0,Q=1,S=2)
le_sex = LabelEncoder()
X["Sex"] = le_sex.fit_transform(X["Sex"])le_embarked = LabelEncoder()
X["Embarked"] = le_embarked.fit_transform(X["Embarked"])# 4. 划分训练集和测试集(8:2)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
步骤 3:构建并训练决策树
基于DecisionTreeClassifier
,关键参数参考 PPT 说明:
# 初始化决策树模型(用基尼指数,限制最大深度为5,避免过拟合)
dt_model = DecisionTreeClassifier(criterion="gini", # 划分标准:gini(基尼指数)或entropy(信息熵)max_depth=5, # 最大深度:预剪枝策略,避免过拟合min_samples_leaf=5,# 叶子节点最少样本数:预剪枝策略random_state=42 # 固定随机种子,保证结果可复现
)# 训练模型
dt_model.fit(X_train, y_train)# 交叉验证评估(5折交叉验证,更可靠)
cv_scores = cross_val_score(dt_model, X_train, y_train, cv=5)
print(f"5折交叉验证准确率:{cv_scores.mean():.2f} ± {cv_scores.std():.2f}")
步骤 4:模型评估与可视化
# 1. 在测试集上评估精度
y_pred = dt_model.predict(X_test)
test_acc = accuracy_score(y_test, y_pred)
print(f"测试集准确率:{test_acc:.2f}")# 2. 输出分类报告(精确率、召回率、F1值)
print("\n分类报告:")
print(classification_report(y_test, y_pred, target_names=["遇难", "幸存"]))# 3. 可视化决策树(直观查看划分逻辑)
plt.figure(figsize=(15, 10)) # 设置图大小
plot_tree(dt_model,feature_names=features,class_names=["遇难", "幸存"],filled=True, # 用颜色填充节点(颜色越深,纯度越高)fontsize=10
)
plt.title("泰坦尼克号幸存者预测决策树", fontsize=15)
plt.show()
步骤 5:参数调优(可选)
通过GridSearchCV
优化关键参数(如max_depth
、min_samples_leaf
):
from sklearn.model_selection import GridSearchCV# 定义参数网格
param_grid = {"max_depth": [3, 5, 7, 10],"min_samples_leaf": [2, 5, 10],"criterion": ["gini", "entropy"]
}# 网格搜索(5折交叉验证)
grid_search = GridSearchCV(DecisionTreeClassifier(random_state=42), param_grid, cv=5)
grid_search.fit(X_train, y_train)# 输出最优参数和最优分数
print(f"最优参数:{grid_search.best_params_}")
print(f"最优交叉验证准确率:{grid_search.best_score_:.2f}")# 用最优模型预测
best_dt = grid_search.best_estimator_
best_y_pred = best_dt.predict(X_test)
print(f"最优模型测试集准确率:{accuracy_score(y_test, best_y_pred):.2f}")
4.3 结果分析
- 决策树的可视化结果会显示:每个节点的划分特征(如 “Sex”)、划分阈值(如 “Sex≤0.5”,即女性)、节点的基尼指数、样本数量、类别分布;
- 通常情况下,“性别(Sex)” 是泰坦尼克号幸存者预测的最重要特征(女性幸存率远高于男性),其次是 “船票等级(Pclass)”(一等舱幸存率高于二等舱、三等舱);
- 通过预剪枝(如限制
max_depth=5
),模型测试集准确率可达 75%-85%,泛化能力较好。
五、总结
决策树作为机器学习领域经典且实用的算法,凭借直观易懂、可解释性强的特性,在分类与回归任务中均占据重要地位。本文围绕其原理与实战展开深度剖析,为你搭建从理论理解到实践应用的完整知识链路:
5.1算法演进:从基础到优化
- ID3:以信息增益为划分依据,开启决策树构建先河,却因对多取值属性的天然偏好,易引发过拟合风险,限制了实际场景的普适性。
- C4.5:引入信息增益率,通过固有值对信息增益归一化,有效修正 ID3 的属性偏好问题,提升了决策树构建的合理性。
- CART:支持二叉树结构,分类任务采用基尼指数(计算高效,避免对数运算),回归任务适配均方误差,兼顾效率与泛化能力,成为工业界常用算法。
5.2关键技术:突破应用壁垒
- 连续值处理:采用排序、确定候选分界点、贪婪选最优的策略,将连续特征离散化,让决策树能适配含连续值的真实数据集,拓宽应用场景。
- 剪枝策略:预剪枝通过限制树深度、叶子节点样本数等,边构建边干预,计算高效但可能欠拟合;后剪枝先构建完整树,再反向剪枝,依赖验证集评估,虽计算量大,但泛化能力更优,二者配合可有效平衡过拟合与欠拟合。
5.3实战价值:泰坦尼克号预测验证
通过泰坦尼克号幸存者预测实战,完成数据预处理(缺失值填充、分类特征编码 )、模型构建(用DecisionTreeClassifier
,结合预剪枝参数调优)、评估(交叉验证、分类报告)全流程。结果显示,决策树能有效捕捉关键特征(如性别、船票等级对幸存率的影响 ),经合理剪枝后,测试集准确率可达 75%-85%,验证了算法在真实场景的应用价值。
从算法原理的层层拆解,到关键技术的逐个突破,再到实战项目的落地验证,决策树展现出强大的可解释性与实用性。掌握它,不仅能助力你解决分类回归问题,更能让你理解 “如何用简单逻辑挖掘数据规律”,为后续进阶复杂模型(如随机森林、梯度提升树 )筑牢基础,在机器学习的探索之路上稳步前行。