当前位置: 首页 > news >正文

机器学习之决策树:从原理到实战(附泰坦尼克号预测任务)

大家好!今天我们来深入探讨机器学习中经典且实用的算法 —— 决策树。无论是数据分类还是回归任务,决策树都以其直观易懂、可解释性强的特点被广泛应用。本文将从决策树的核心算法(ID3、C4.5、CART)讲起,带你理解连续值处理、剪枝策略,最后通过代码实战完成泰坦尼克号幸存者预测任务,保证内容详实且易上手。

一、决策树核心算法:从信息增益到基尼指数

决策树的本质是通过 “划分属性” 构建一棵 “判断树”,每一个内部节点代表一次属性判断,每一个叶子节点代表最终的类别(分类任务)或输出值(回归任务)。不同算法的核心差异在于划分属性的选择标准,常见的有 ID3、C4.5、CART 三种。

1.1 ID3 算法:以 “信息增益” 为核心

ID3 是最早的决策树算法之一,它通过 “信息增益” 选择最优划分属性,核心逻辑是:信息增益越大,用该属性划分后数据集的 “纯度提升” 越明显

关键概念
  • 信息熵(Entropy):衡量数据集纯度的指标,熵越小,数据越纯。公式为:
    Entropy(D)=−∑k=1n​pk​log2​pk​
    其中pk​是数据集D中第k类样本的占比。
  • 信息增益:某属性a对数据集D的信息增益,等于 “划分前的熵” 减去 “划分后各子集熵的加权和”,公式为:
    Gain(D,a)=Entropy(D)−∑v∈Values(a)​∣D∣∣Dv​∣​Entropy(Dv​)
局限性

ID3 对可取值数目多的属性有天然偏好。例如 “样本编号” 这类属性,每个取值对应一个子集,划分后熵为 0,信息增益最大,但显然无实际意义(过拟合风险极高)。

1.2 C4.5 算法:用 “信息增益率” 修正偏差

为解决 ID3 的属性偏好问题,C4.5 引入 “信息增益率” 作为划分标准,通过 “自身熵” 对信息增益进行归一化。

关键概念
  • 信息增益率:公式为:
    Gain_ratio(D,a)=IV(a)Gain(D,a)​
    其中IV(a)=−∑v∈Values(a)​∣D∣∣Dv​∣​log2​∣D∣∣Dv​∣​ 是属性a的 “固有值”—— 属性取值越多,IV(a)越大,从而抑制对多取值属性的偏好。
示例:基于 “是否出去玩” 数据集

以 PPT 中的示例数据集为例(共 7 条样本,特征包括天气、温度、湿度、是否多云,标签为 “是否出去玩”):

ID天气温度湿度是否多云是否出去玩
1
2适中
3正常
4适中正常
5
6正常
7正常

若用 C4.5 计算 “天气” 属性的信息增益率:

  1. 先计算划分前的总熵Entropy(D);
  2. 计算 “天气” 属性各取值(晴、阴、雨)对应的子集熵,得到信息增益天气;
  3. 计算 “天气” 属性的固有值天气;
  4. 最终得到信息增益率天气,并与其他属性(如温度、湿度)比较,选择最大的作为划分属性。

1.3 CART 算法:用 “基尼指数” 简化计算

CART(分类与回归树)是目前应用最广的决策树算法,支持分类和回归任务。在分类任务中,它用 “基尼指数” 替代信息熵,计算效率更高(无需对数运算)。

关键概念
  • 基尼指数(Gini (D)):反映从数据集D中随机抽取两个样本,类别标记不一致的概率。公式为:
    Gini(D)=1−∑k=1n​pk2​
    其中pk​是第k类样本的占比。
    核心逻辑:pk​越大(数据越纯),Gini (D) 越小。例如,若数据集全为同一类,Gini (D)=0;若两类样本各占 50%,Gini (D)=0.5(纯度最低)。
CART 的特点
  • 只生成二叉树:无论属性有多少取值,每次划分都将数据分为两部分(如 “温度≤25℃” 和 “温度 > 25℃”);
  • 适用场景广:分类任务用基尼指数 / 信息熵,回归任务用均方误差(MSE);
  • 计算高效:避免信息熵的对数运算,适合大规模数据。

二、决策树的 “难点突破”:连续值处理

现实数据中,很多特征是连续值(如年龄、收入),而 ID3/C4.5/CART 原生只支持离散值。如何处理连续值?核心思路是将连续值 “离散化”,具体步骤如下(基于 PPT 示例):

步骤 1:排序连续值

以 PPT 中 “应税收入(Taxable Income)” 为例,10 条样本的收入数据为:60K、70K、75K、85K、90K、95K、100K、120K、125K、220K,先按从小到大排序。

步骤 2:确定候选分界点

连续值的 “二分法” 划分中,候选分界点为相邻两个值的中点。例如上述 10 个值有 9 个候选分界点:65K((60+70)/2)、72.5K((70+75)/2)、…、172.5K((125+220)/2)。

步骤 3:用 “贪婪算法” 选最优分界点

对每个候选分界点,计算划分后的 “信息增益”(ID3/C4.5)或 “基尼指数”(CART),选择最优的分界点。例如:

  • 若用 “Taxable Income ≤80K” 划分,得到两个子集,计算其信息增益;
  • 若用 “Taxable Income ≤97.5K” 划分,计算另一组信息增益;
  • 最终选择信息增益最大(或基尼指数最小)的分界点作为该属性的划分标准。

本质上,这一步是将连续特征转化为 “是 / 否” 的离散特征,从而适配决策树的划分逻辑。

三、决策树的 “防坑指南”:剪枝策略

决策树有一个致命问题 ——过拟合风险极高。理论上,它可以无限分裂节点,直到每个叶子节点只包含一个样本(完全拟合训练数据,但对新数据泛化能力差)。因此,“剪枝” 是决策树不可或缺的步骤,核心是通过 “牺牲部分训练精度” 换取 “更好的泛化能力”,常见策略分为预剪枝和后剪枝。

3.1 预剪枝:“边建边剪”,高效实用

预剪枝是在决策树构建过程中同步进行剪枝,提前停止节点分裂,避免过拟合。它的优点是计算量小、效率高,是工业界更常用的方式。

常见预剪枝规则
  • 限制树的最大深度:例如设置max_depth=10,树的深度超过 10 则停止分裂(深度越大,过拟合风险越高,推荐值 5-20);
  • 限制叶子节点最少样本数:例如设置min_samples_leaf=5,若某节点分裂后叶子节点样本数 < 5,则不分裂;
  • 限制叶子节点总数:例如设置max_leaf_nodes=20,叶子节点数超过 20 则停止;
  • 限制信息增益阈值:若某属性的信息增益(或增益率、基尼指数)低于阈值,则不分裂该节点。
预剪枝的特点
  • 优点:操作简单、计算高效,能有效避免过拟合;
  • 缺点:可能 “欠剪枝”—— 因提前停止分裂,导致模型未充分学习数据规律,精度偏低。

3.2 后剪枝:“先建后剪”,精度更高

后剪枝是先构建完整的决策树(不限制分裂),再从叶子节点向根节点反向剪枝,通过 “损失函数” 判断是否保留子树。它的优点是泛化能力更强,缺点是计算量较大。

核心:损失函数与 α 参数

后剪枝的核心是比较 “保留子树” 和 “剪枝为叶子节点” 的损失,选择损失更小的方案。损失函数公式(基于 CART 分类树):
Loss(T)=C(T)+α×∣T∣

  • C(T):子树T的 “训练误差”,分类任务中常用 “所有叶子节点的基尼指数加权和”;
  • ∣T∣:子树T的叶子节点数量;
  • α:正则化参数,控制剪枝强度:
    • α越大:正则化越强,更倾向于剪枝(减少叶子节点数),避免过拟合,但可能降低精度;
    • α越小:正则化越弱,更倾向于保留复杂子树,精度可能更高,但过拟合风险大。
后剪枝示例(基于 PPT“好瓜 / 坏瓜” 案例)

以 PPT 中 “色泽 =?” 和 “纹理 =?” 分支的剪枝为例:

  • 对 “色泽 =?” 分支:剪枝前验证集精度 57.1%,剪枝后 71.4%(损失更小),因此选择剪枝;
  • 对 “纹理 =?” 分支:剪枝前验证集精度 42.9%,剪枝后 57.1%(损失更小),因此选择剪枝。

后剪枝的关键是用验证集评估精度,确保剪枝后的模型对新数据有更好的表现。

四、代码实战:用决策树预测泰坦尼克号幸存者

理论讲完,我们通过实战巩固 —— 用scikit-learnDecisionTreeClassifier构建决策树,预测泰坦尼克号乘客是否幸存。

4.1 数据准备

首先需要获取泰坦尼克号数据集(若未下载,可从Kaggle 泰坦尼克号竞赛页面下载,核心文件为train.csv)。数据集包含乘客的基本信息(如年龄、性别、船票等级)和是否幸存的标签(Survived:1 = 幸存,0 = 遇难)。

4.2 核心代码实现

步骤 1:导入所需库
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.preprocessing import LabelEncoder  # 处理分类特征
from sklearn.metrics import accuracy_score, classification_report  # 评估模型
步骤 2:加载并预处理数据

决策树无法直接处理字符串特征(如性别、Embarked),需先进行编码;同时处理缺失值(如年龄缺失):

# 加载数据(确保文件路径正确,若报错FileNotFoundError,参考前文解决)
data = pd.read_csv("train.csv", index_col=0)  # index_col=0用PassengerId作为索引# 1. 选择核心特征(剔除无关特征如Name、Ticket、Cabin)
features = ["Pclass", "Sex", "Age", "SibSp", "Parch", "Fare", "Embarked"]
X = data[features]
y = data["Survived"]  # 标签:是否幸存# 2. 处理缺失值(年龄用均值填充,Embarked用众数填充)
X["Age"].fillna(X["Age"].mean(), inplace=True)
X["Embarked"].fillna(X["Embarked"].mode()[0], inplace=True)# 3. 编码分类特征(Sex:男=1,女=0;Embarked:C=0,Q=1,S=2)
le_sex = LabelEncoder()
X["Sex"] = le_sex.fit_transform(X["Sex"])le_embarked = LabelEncoder()
X["Embarked"] = le_embarked.fit_transform(X["Embarked"])# 4. 划分训练集和测试集(8:2)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
步骤 3:构建并训练决策树

基于DecisionTreeClassifier,关键参数参考 PPT 说明:

# 初始化决策树模型(用基尼指数,限制最大深度为5,避免过拟合)
dt_model = DecisionTreeClassifier(criterion="gini",  # 划分标准:gini(基尼指数)或entropy(信息熵)max_depth=5,       # 最大深度:预剪枝策略,避免过拟合min_samples_leaf=5,# 叶子节点最少样本数:预剪枝策略random_state=42    # 固定随机种子,保证结果可复现
)# 训练模型
dt_model.fit(X_train, y_train)# 交叉验证评估(5折交叉验证,更可靠)
cv_scores = cross_val_score(dt_model, X_train, y_train, cv=5)
print(f"5折交叉验证准确率:{cv_scores.mean():.2f} ± {cv_scores.std():.2f}")
步骤 4:模型评估与可视化
# 1. 在测试集上评估精度
y_pred = dt_model.predict(X_test)
test_acc = accuracy_score(y_test, y_pred)
print(f"测试集准确率:{test_acc:.2f}")# 2. 输出分类报告(精确率、召回率、F1值)
print("\n分类报告:")
print(classification_report(y_test, y_pred, target_names=["遇难", "幸存"]))# 3. 可视化决策树(直观查看划分逻辑)
plt.figure(figsize=(15, 10))  # 设置图大小
plot_tree(dt_model,feature_names=features,class_names=["遇难", "幸存"],filled=True,  # 用颜色填充节点(颜色越深,纯度越高)fontsize=10
)
plt.title("泰坦尼克号幸存者预测决策树", fontsize=15)
plt.show()
步骤 5:参数调优(可选)

通过GridSearchCV优化关键参数(如max_depthmin_samples_leaf):

from sklearn.model_selection import GridSearchCV# 定义参数网格
param_grid = {"max_depth": [3, 5, 7, 10],"min_samples_leaf": [2, 5, 10],"criterion": ["gini", "entropy"]
}# 网格搜索(5折交叉验证)
grid_search = GridSearchCV(DecisionTreeClassifier(random_state=42), param_grid, cv=5)
grid_search.fit(X_train, y_train)# 输出最优参数和最优分数
print(f"最优参数:{grid_search.best_params_}")
print(f"最优交叉验证准确率:{grid_search.best_score_:.2f}")# 用最优模型预测
best_dt = grid_search.best_estimator_
best_y_pred = best_dt.predict(X_test)
print(f"最优模型测试集准确率:{accuracy_score(y_test, best_y_pred):.2f}")

4.3 结果分析

  • 决策树的可视化结果会显示:每个节点的划分特征(如 “Sex”)、划分阈值(如 “Sex≤0.5”,即女性)、节点的基尼指数、样本数量、类别分布;
  • 通常情况下,“性别(Sex)” 是泰坦尼克号幸存者预测的最重要特征(女性幸存率远高于男性),其次是 “船票等级(Pclass)”(一等舱幸存率高于二等舱、三等舱);
  • 通过预剪枝(如限制max_depth=5),模型测试集准确率可达 75%-85%,泛化能力较好。

五、总结

决策树作为机器学习领域经典且实用的算法,凭借直观易懂、可解释性强的特性,在分类与回归任务中均占据重要地位。本文围绕其原理与实战展开深度剖析,为你搭建从理论理解到实践应用的完整知识链路:

5.1算法演进:从基础到优化

  • ID3:以信息增益为划分依据,开启决策树构建先河,却因对多取值属性的天然偏好,易引发过拟合风险,限制了实际场景的普适性。
  • C4.5:引入信息增益率,通过固有值对信息增益归一化,有效修正 ID3 的属性偏好问题,提升了决策树构建的合理性。
  • CART:支持二叉树结构,分类任务采用基尼指数(计算高效,避免对数运算),回归任务适配均方误差,兼顾效率与泛化能力,成为工业界常用算法。

5.2关键技术:突破应用壁垒

  • 连续值处理:采用排序、确定候选分界点、贪婪选最优的策略,将连续特征离散化,让决策树能适配含连续值的真实数据集,拓宽应用场景。
  • 剪枝策略:预剪枝通过限制树深度、叶子节点样本数等,边构建边干预,计算高效但可能欠拟合;后剪枝先构建完整树,再反向剪枝,依赖验证集评估,虽计算量大,但泛化能力更优,二者配合可有效平衡过拟合与欠拟合。

5.3实战价值:泰坦尼克号预测验证

通过泰坦尼克号幸存者预测实战,完成数据预处理(缺失值填充、分类特征编码 )、模型构建(用DecisionTreeClassifier,结合预剪枝参数调优)、评估(交叉验证、分类报告)全流程。结果显示,决策树能有效捕捉关键特征(如性别、船票等级对幸存率的影响 ),经合理剪枝后,测试集准确率可达 75%-85%,验证了算法在真实场景的应用价值。

从算法原理的层层拆解,到关键技术的逐个突破,再到实战项目的落地验证,决策树展现出强大的可解释性与实用性。掌握它,不仅能助力你解决分类回归问题,更能让你理解 “如何用简单逻辑挖掘数据规律”,为后续进阶复杂模型(如随机森林、梯度提升树 )筑牢基础,在机器学习的探索之路上稳步前行。

http://www.xdnf.cn/news/1328203.html

相关文章:

  • Mac(七)右键新建文件的救世主 iRightMouse
  • 大数据MapReduce架构:分布式计算的经典范式
  • 20250819 强连通分量,边双总结
  • 从线性回归到神经网络到自注意力机制 —— 激活函数与参数的演进
  • 人工智能统一信息结构的挑战与前景
  • 比赛准备之环境配置
  • 进程间的通信1.(管道,信号)
  • LINUX 软件编程 -- 线程
  • 决策树(续)
  • LeetCode100-560和为K的子数组
  • 决策树1.1
  • 项目一系列-第5章 前后端快速开发
  • 项目管理.管理理念学习
  • react-quill-new富文本编辑器工具栏上传、粘贴截图、拖拽图片将base64改上传服务器再显示
  • LeetCode算法日记 - Day 16: 连续数组、矩阵区域和
  • 第4章 React状态管理基础
  • 算法训练营day56 图论⑥ 108. 109.冗余连接系列
  • 项目过程管理的重点是什么
  • Ansible 角色管理
  • 点大餐饮独立版系统源码v1.0.3+uniapp前端+搭建教程
  • GStreamer无线图传:树莓派到计算机的WiFi图传方案
  • GEO 优化专家孟庆涛:技术破壁者重构 AI 时代搜索逻辑
  • RESTful API 开发实践:淘宝商品详情页数据采集方案
  • Apache IoTDB:大数据时代时序数据库选型的技术突围与实践指南
  • 从0到1认识Rust通道
  • Redis-缓存-击穿-分布式锁
  • 无人机场景 - 目标检测数据集 - 山林野火烟雾检测数据集下载「包含VOC、COCO、YOLO三种格式」
  • 国产!全志T113-i 双核Cortex-A7@1.2GHz 工业开发板—ARM + FPGA通信案例
  • 如何免费给视频加字幕
  • Linux的ALSA音频框架学习笔记