当前位置: 首页 > ai >正文

机器学习中决策树

一、简介

1.定义

        决策树是一种属性结构,其中:内部节点:代表对某个特征的判断(特征),分支:代表判断结果,叶子节点:代表最终分类结果(标签)。

2.决策树构建三步骤

        (1)特征选择:筛选对分类 / 回归贡献最大的特征

        (2)决策树生成:基于选定特征递归分裂数据集,生成初步树结构;

        (3)剪枝:缓解过拟合(初步树可能过度贴合训练数据,泛化能力差,需裁剪冗余分支)。

        注:决策树的核心优势是可解释性极强(类似 “if-else” 逻辑,易理解),无需特征标准化(如年龄和收入无需统一量级),但缺点是易过拟合,需依赖剪枝优化。

二、ID3决策树(离散特征专用)

1.核心指标:信息熵与信息增益

(1)信息熵(Entropy)

        定义:信息论中衡量数据不确定性的指标,熵越大,数据混乱度(不确定性)越高;熵越小,数据纯度越高。

        计算公式:对于数据集D,若包含k类样本,各类别占比为p1,p2...pk,则信息熵为:(对数底为 2,单位为 “比特”)

        案例验证:

                数据α(ABCDEFGH):8 类样本,每类占比(1/8),(H(α) = -8×(1/8)log_2(1/8) = 3);

                数据β(AAAABBCD):4 类样本(A:1/2,B:1/4,C:1/8,D:1/8),(H(β) = -(1/2log_2 1/2 + 1/4log_2 1/4 + 2×1/8log_2 1/8) ≈ 1.75);

                结论:(H(α) > H(β)),数据 α 更混乱。

        注:信息熵由香农(Shannon)在 1948 年《通信的数学理论》中提出,最初用于衡量通信中的 “信息不确定性”,后来被引入机器学习作为特征选择的核心指标。

(2)信息增益

        定义:特征a对数据集D的信息增益,等于 “数据集原熵H(D)” 减去 “特征a条件下的条件熵H(D|a)”,代表通过特征a分裂后,数据不确定性减少的程度。

        公式:g(D,a)=H(D)-H(D|a)

        (条件熵H(D|a):按特征α划分后的各子集熵的加权平均,权重为子集占总数据集的比例)

        案例计算(6 个样本:3A、3B,特征a分 α(4 样本:3A1B)、β(2 样本:2B)):

        原熵(H(D) = -(3/6log_2 3/6 + 3/6log_2 3/6) = 1);

        条件熵(H(D|a) = (4/6)×[-(3/4log_2 3/4 + 1/4log_2 1/4)] + (2/6)×[-(2/2log_2 2/2)] ≈ 0.54);

        信息增益g(D,a)=1-0.54=0.46。

2.ID3决策树构建流程

        (1)计算数据集中所有特征的信息增益;

        (2)选择信息增益最大的特征作为当前节点的分裂特征;

        (3)按该特征的取值将数据集拆分为子集;

        (4)对每个子集重复步骤 1-3,直到所有子集纯度达到阈值(如子集全为同一类别)或无特征可分裂。

3.典型案例:论坛客户流失分析

        数据:15 条样本(5 个流失正样本,10 个未流失负样本),特征为 “性别”“活跃度”;

        步骤:计算原熵→分别计算 “性别”“活跃度” 的信息增益→比较得出 “活跃度信息增益更大”,对流失的影响更显著。

        :ID3 的致命缺陷:①仅支持离散特征(无法处理年龄、收入等连续特征);②偏向选择取值数量多的特征(如 “用户 ID” 这类唯一值特征,信息增益极高,但无实际意义),后续 C4.5 决策树专门解决此问题。

三、C4.5决策树(ID3改进版)

1.ID3的痛点与C4.5的改进方向

        ID3痛点:偏向选择取值多的特征(如特征 b 有 6 个取值,特征 a 仅 2 个,ID3 易选 b,但 b 可能无实际预测价值);

        C4.5改进:用信息增益率替代信息增益,引入 “惩罚系数” 修正多取值特征的优势。

2.核心指标:信息增益率

        定义:信息增益率 = 信息增益 / 特征熵(特征熵:以特征a为随机变量的熵,衡量特征自身的不确定性);

        惩罚逻辑:若特征a取值多,其特征熵大,信息增益率会被 “稀释”,从而避免偏向多取值特征;

        案例验证(特征 a:2 取值,特征 b:6 取值):

                特征b的信息增益可能高,但特征熵更大(6 个唯一值,熵≈2.58),信息增益率低;

                特征a的信息增益率更高,最终被选为分裂特征。

3.C4.5的额外优势

        支持连续特征:通过 “离散化” 处理(如将年龄分为 [0-18,19-30,31+]);

        处理缺失值:用 “样本权重” 弥补(如某样本缺失 “年龄”,则按其他特征的分布分配权重);

        剪枝优化:自带后剪枝逻辑,降低过拟合风险。

        注:C4.5 的局限:①计算复杂度高(需多次计算熵和增益率);②无法处理超大数据集(需将数据全部载入内存,不支持分布式),工业界常用 CART 或随机森林(基于 CART)替代。

四、CART决策树(分类+回归双用途)

        CART(Classification and Regression Tree)是最常用的决策树模型,支持分类任务(用基尼指数)和回归任务(用平方损失),且无论特征类型,均生成二叉树(每个节点仅分 2 个分支)

1.CART分类树(预测离散类别)

        (1)核心指标:基尼指数

                定义:从数据集D中随机抽取 2 个样本,其类别标记不一致的概率,取值范围 [0,0.5],值越小,数据纯度越高;

                公式:

                基尼指数计算逻辑:选择使 “分裂后总基尼指数最小” 的特征和分裂点(如离散特征按 “是否为某取值” 分,连续特征按 “是否大于某阈值” 分)。

        (2)典型案例:是否拖欠贷款预测

                数据:10 条样本,特征为 “是否有房”“婚姻状况”“年收入”,目标为 “是否拖欠贷款”;

                计算关键:

                        [是否有房]:有房样本(3 个,全为 “不拖欠”,基尼 = 0),无房样本(7 个,4 不拖欠 3 拖欠,基尼≈0.49),总基尼指数 =(3/10)×0 +(7/10)×0.49≈0.343;

                        [婚姻状况]:按 “是否已婚” 分裂,已婚样本(4 个,全不拖欠,基尼 = 0),未婚样本(6 个,3 不拖欠 3 拖欠,基尼 = 0.5),总基尼指数 =(4/10)×0 +(6/10)×0.5=0.3;

                        [年收入]:按 97.5 为阈值分裂,总基尼指数 = 0.3;

                结论:选择 “婚姻状况(是否已婚)” 或 “年收入(97.5)” 作为分裂特征(基尼指数最小)。

2.CART回归树(预测连续值)

(1)与分类树的核心区别

维度CART 分类树CART 回归树
输出类型离散类别(如 “拖欠 / 不拖欠”)连续值(如 “房价”“收入”)
损失函数基尼指数平方损失(Loss=(f(x)-y)^2)
叶子节点输出子集中多数类别的标签子集中所有样本的均值

(2)CART回归树构建流程

        1.对特征x的取值排序,取相邻值的均值作为候选分裂点(如(x=[1,2,...,10]),候选点为 1.5,2.5,...,9.5);

        2.对每个候选点,将数据分为 “≤分裂点” 和 “>分裂点” 两个子集,计算两个子集的平方损失之和

        3.选择平方损失之和最小的分裂点作为当前特征的最优分裂点;

        4.对每个子集重复步骤 1-3,直到子集满足停止条件(如子集样本数≤阈值)。

(3)典型案例:特征x与目标y的回归

        数据:x=[1-10],(y=[5.56,5.7,5.91,6.4,6.8,7.05,8.9,8.7,9,9.05];

        步骤:

                1.候选分裂点为 1.5-9.5,计算各点平方损失,发现s=6.5时损失最小(m(s)=1.93);

                2.按s=6.5分裂为左子集(x≤6,6 个样本)和右子集(x>6,4 个样本);

                3.对左子集继续分裂,发现s=3.5时损失最小,最终生成二叉树。

        注:CART 的核心优势:①支持分类 + 回归双任务;②二叉树结构简洁,计算效率高;③可处理离散 / 连续特征,工业界应用最广(如随机森林、GBDT 等集成模型的基础均为 CART)。

五、泰塔尼克案例

1.案例背景

        数据:泰坦尼克号乘客数据(特征:Pclass(舱位等级)、Age(年龄)、Sex(性别)等,目标:Survived(是否生存));

        核心逻辑:通过决策树预测乘客是否能生存(历史事实:妇女、儿童、高舱位乘客生存率更高)。

2.实现步骤(Python+sklearn)

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, roc_auc_score, roc_curve
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt
# 设置中文字体,解决中文显示问题
plt.rcParams["font.family"] = ["SimHei"]
# 解决负号显示问题
plt.rcParams['axes.unicode_minus'] = False#读取数据
data = pd.read_csv('../data/train.csv')
# data.info()
#数据预处理
# df = data.isnull().sum()
# print(df)
data['Sex'] = data['Sex'].map({'female': 0, 'male': 1})
data['Age'] = data['Age'].fillna(data['Age'].mean())  # 数值列填充
# data.fillna({'Age':data['Age'].mean()}, inplace=True)
data['Embarked'] = data['Embarked'].fillna(data['Embarked'].mode()[0])
data['Embarked'] = data['Embarked'].map({'S':0, 'C':1, 'Q':2})
data.drop('Cabin', axis=1, inplace=True)
data.info()X = data[['Pclass','Sex','Age', 'SibSp', 'Parch','Fare', 'Embarked']]
y = data['Survived']
# X = pd.get_dummies(X['Embarked'], prefix='Embarked')
# print(X.shape)
#特征处理
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=5999,stratify=y)
# transfer = StandardScaler()
# X_train = transfer.fit_transform(X_train)
# X_test = transfer.transform(X_test)
#创建模型
es = DecisionTreeClassifier()
#训练模型
es.fit(X_train,y_train)
#模型预测
y_predict = es.predict(X_test)
#模型评估
y_proba = es.predict_proba(X_test)[:,1]
print(f"准确值:{accuracy_score(y_test, y_predict)}")
print(f"召回率:{recall_score(y_test, y_predict)}")
print(f"精准率:{precision_score(y_test, y_predict)}")
print(f"f1分数:{f1_score(y_test, y_predict)}")
print(f"AUC指标:{roc_auc_score(y_test, y_proba)}")plt.figure(figsize=(8,4))
plot_tree(es,filled=True,max_depth=30)
plt.show()

3.关键结论

        重要特征:Sex(性别)>Pclass(舱位)>Age(年龄),符合 “妇女优先、高舱位优先” 的历史事实;

        调参建议:通过max_depth(树深)、min_samples_leaf(叶子节点最小样本数)限制树的复杂度,降低过拟合。

        注:实际项目中,决策树的调参核心是控制复杂度:①max_depth不宜过大(如超过 10 易过拟合);②min_samples_split(内部节点分裂最小样本数)设为 5-10,避免小样本分裂;③ccp_alpha(成本复杂度剪枝参数)可自动剪枝,sklearn 推荐使用。

六、决策树剪枝(防止过拟合的核心手段)

1.剪枝的必要性

        决策树若不剪枝,会过度贴合训练数据(如每个叶子节点仅 1 个样本),导致泛化能力差(测试集准确率低),剪枝通过 “删除冗余分支”,保留核心逻辑,提升泛化能力。

2.两种剪枝方法对比

剪枝类型核心逻辑优点缺点
预剪枝树生成过程中,对每个节点先判断:若分裂后验证集准确率无提升,则停止分裂,标记为叶子节点训练 / 测试效率高,避免冗余分支生成可能欠拟合(过早停止分裂,错过后续有效分支)
后剪枝先生成完整决策树,再自底向上遍历非叶节点:若将子树替换为叶子节点后准确率提升,则剪枝泛化能力强,欠拟合风险低训练效率低(需生成完整树再剪枝)

3.案例:西瓜数据集剪枝

        预剪枝:生成 “脐部 =?” 节点后,若 “色泽 =?” 分裂后验证集准确率从 71.4% 降至 57.1%,则禁止分裂,保留 “脐部” 节点;

        后剪枝:先生成含 6 个内部节点的完整树,再判断 “色泽 =?”“纹理 =?” 等节点,若剪枝后准确率提升(如从 57.1%→71.4%),则剪枝。

        注:工业界常用后剪枝 + 预剪枝结合:①先用预剪枝快速生成基础树;②再用后剪枝优化局部分支;③sklearn 中DecisionTreeClassifierccp_alpha参数实现 “成本复杂度剪枝”(后剪枝的一种),通过最小化 “训练损失 +α× 树复杂度” 选择最优树。

七、三大分类决策树对比

决策树类型提出时间分支方式支持特征类型核心优势核心局限
ID31975信息增益仅离散计算简单,易理解不支持连续 / 缺失值,偏多取值特征
C4.51993信息增益率离散 + 连续解决 ID3 偏倚,支持缺失值计算复杂,不支持大数据集
CART1984基尼指数(分类)离散 + 连续支持分类 / 回归,二叉树高效对异常值敏感
平方损失(回归)

八、核心知识点梳理

1.决策树定义:内部节点是特征判断,叶子是分类结果,可明确特征重要性;

2.信息熵:熵越大,混乱度越高

3.信息熵增率:缓解多取值特征偏倚,C4.5 核心

4.基尼指数:CART 分类树核心,值越小纯度越高

5.剪枝方法:预剪枝(生成时剪)、后剪枝(生成后剪),均为防过拟合

http://www.xdnf.cn/news/19970.html

相关文章:

  • LeetCode 48 - 旋转图像算法详解(全网最优雅的Java算法
  • 安全与效率兼得:工业控制系统如何借力数字孪生实现双赢?
  • CPTS-Manager ADCS ESC7利用
  • HTML图片标签及路径详解
  • 代码随想录训练营第三十一天|LeetCode56.合并区间、LeetCode738.单调递增的数字
  • freertos下printf(“hello\r\n“)和printf(“hello %d\r\n“,i)任务堆栈消耗有何区别
  • 金贝 KA Box 1.18T:一款高效能矿机的深度解析
  • Python 第三方自定义库开发与使用教程
  • Redis是单线程的,为啥那么快呢?经典问题
  • 第六章 Cesium 实现简易河流效果
  • 热计量表通过M-Bus接口实现无线集抄系统的几种解决方
  • 2025国赛C题题目及最新思路公布!
  • ubuntu20.04配置运行ODM2.9.2教程,三维重建,OpenDroneMap/ODM2.9.2
  • 智能家居芯片:技术核心与创新突破
  • Spring Cloud Ribbon 核心原理
  • 数字化转型:从锦上添花到生存必需——2025年零售行业生存之道
  • Function Call实战:用GPT-4调用天气API,实现实时信息查询
  • Matlab中的积分——函数int()和quadl()
  • PDF24 Creator:免费的多功能PDF工具
  • OPC UA双层安全认证模型解析
  • 【蓝桥杯选拔赛真题64】C++最大空白区 第十四届蓝桥杯青少年创意编程大赛 算法思维 C++编程选拔赛真题解
  • 大小端存储的理解与判断方法
  • Cypress 测试框架:轻松实现端到端自动化测试!
  • 从零开始的python学习——元组
  • PostgreSQL与SQL Server:B树索引差异及去重的优势
  • Webus 与中国国际航空合作实现 XRP 支付
  • DeepSeek文献太多太杂?一招制胜:学术论文检索的“核心公式”与提问艺术
  • Java+Vue构建的MES智能管理系统,集成生产计划、执行、监控与优化功能,支持产品、车间、工艺、客户、供应商等多维度管理,含完整源码,助力企业高效生产
  • LeetCode算法日记 - Day 31: 判定是否互为字符重排、存在重复元素
  • nextcyber——常见应用攻击