当前位置: 首页 > news >正文

机器学习 入门——决策树分类

决策树是一种直观且强大的机器学习算法,适用于分类和回归任务。本文将全面介绍决策树分类的原理、实现、调优和实际应用。

一、什么是决策树分类

1.概念

决策树分类是一种树形结构的分类模型,它通过递归地将数据集分割成更小的子集来构建决策规则。就像我们日常生活中做决策一样(例如:如果天气晴朗,就去公园;否则在家看电影),决策树通过一系列的判断条件来对数据进行分类。下图为一个决策树

2.构建过程

  1. ​特征选择​​:

    • 使用指标(如信息增益、增益率或基尼指数)选择最佳分裂特征。

    • ​信息增益​​(ID3算法):选择使信息熵下降最多的特征。

    • ​增益率​​(C4.5算法):解决信息增益对多值特征的偏好问题。

    • ​基尼指数​​(CART算法):衡量数据不纯度,值越小纯度越高。

  2. ​节点分裂​​:

    • 根据特征的阈值(连续值)或类别(离散值)将数据划分为子集。

    • 递归处理子集,直到满足停止条件。

  3. ​剪枝策略​​(防止过拟合):

    • ​预剪枝​​:在分裂前评估,若增益不足则停止分裂。

    • ​后剪枝​​:先构建完整树,再自底向上剪去不重要的分支。

二、决策树的分类标准

1、信息增益(Information Gain)

1. 核心概念​

​(1)熵(Entropy)​

  • ​定义​​:衡量数据集的不确定性(混乱程度)。熵越大,数据越无序。

    • 公式:

      
      
      • S:当前数据集;

      • pi:类别 i在数据集中的比例;

      • c:类别总数。

  • ​例子​​:

    • 若数据集全是同一类别(如全为“是”),熵为0(完全确定)。

    • 若类别均匀分布(如“是”“否”各占50%),熵为1(最大不确定性)。

天气

温度

湿度

风力

是否出去玩

多云

正常

上述列表中类别分布为3个“是”,2个“否”。

所以信息熵为

​(2)信息增益(Information Gain)​

  • ​定义​​:划分前后熵的减少量,反映属性对分类的贡献。

    • 公式:

      
      
      • A:候选属性;

      • Values(A):属性 A的所有可能取值;

      • Sv​:属性 A取值为 v的子集。

  • ​目标​​:选择使 IG(S,A)最大的属性 A。

对于上述数据,可以计算出每个属性的信息增益以天气为例

  • 取值​​:晴(2条)、多云(1条)、雨(2条)。

  • 子集熵计算:

    • 晴:2条全为“否” → Entropy(S晴​)=0。

    • 多云:1条全为“是” → Entropy(S多云​)=0。

    • 雨:2条全为“是” → Entropy(S雨​)=0。

  • 信息增益:

    
    (同理可计算其他属性的信息增益,选择最大的作为划分节点。)

2. 构建决策树

通过计算信息增益我们就可​以构建决策树。​

根节点划分(天气)​

  • ​天气 = 晴​​:2条数据,全为“否” → ​​叶节点(否)​​。

  • ​天气 = 多云​​:1条数据,全为“是” → ​​叶节点(是)​​。

  • ​天气 = 雨​​:2条数据,全为“是” → ​​叶节点(是)​​。

此时决策树已完全分类,无需进一步划分(所有子集纯度100%)。

但若假设“雨”的子集不纯(例如有“否”),则需继续划分其他属性。

 2、信息增益比

信息增益比是​​对信息增益的改进​​,用于解决信息增益对多值属性的偏好问题。

通过引入属性的​​固有值(Intrinsic Value)​​,惩罚取值较多的属性,从而平衡划分标准。

公式​​:

其中:

  • IntrinsicValue(A)=

  • Values(A):属性 A的所有取值,Sv​是取值为 v的子集。

  • InformationGain(A)为a属性的信息增益

 计算步骤(示例)​

沿用之前的天气数据集,但假设“湿度”有更多取值以演示效果:

天气

湿度(新)

是否出去玩

80%

85%

多云

90%

75%

60%

​Step 1: 计算信息增益(IG)​

  • 对属性“湿度”(连续属性需离散化,假设分为高/正常):

    • 高(80%, 85%, 90%, 75%):3否,1是 → 熵 ≈ 0.811

    • 正常(60%):1是 → 熵 = 0

    • IG(湿度)=0.971−(54​×0.811+51​×0)≈0.322

​Step 2: 计算固有值(IV)​

  • 湿度取值分布:高(4条)、正常(1条)。

    
    

​Step 3: 计算增益比​

GainRatio(湿度)=0.7220.322​≈0.446

3、GINI系数

基尼系数是决策树(如CART算法)中用于衡量数据​​不纯度​​的指标,表示从数据集中随机抽取两个样本,其类别标签不一致的概率。

GINI系数

  • ​公式​​:

    • S:当前数据集;

    • pi​:类别 i在数据集中的比例;

    • c:类别总数。

  • ​范围​​:0(完全纯净)到 0.5(均匀分布的两类)

特征A条件下的加权基尼系数​

​公式​​:


  • Values(A):特征A的所有可能取值(如“天气”取值为晴、雨、多云)。

  • Sv​:特征A取值为 v的子数据集。

  • S:特征A的数据集数

  • Gini(Sv​):子集 Sv​的基尼系数。​

假设银行根据以下特征决定是否批准贷款申请:

年龄

收入

学历

是否有房产

是否批准贷款

青年

高中

青年

高中

青年

本科

中年

本科

中年

硕士

老年

硕士

​目标​​:预测“是否批准贷款”

​特征​​:年龄、收入、学历、是否有房产

​Step 1: 计算初始基尼系数​

  • 类别分布:3“否”,3“是”。

  • 初始基尼系数:

    Gini(S)=1−((63​)2+(63​)2)=1−(0.25+0.25)=0.5

Step 2: 计算各特征的基尼增益​

​(1)特征:年龄​

  • ​取值​​:青年(3条)、中年(2条)、老年(1条)。

  • ​子集基尼系数​​:

    • 青年:3条(全“否”)→ Gini=1−(1^2+0^2)=0

    • 中年:2条(1“否”,1“是”)→ Gini=1−(0.5^2+0.5^2)=0.5

    • 老年:1条(全“是”)→ Gini=0

  • ​加权基尼系数​​:

    Ginisplit​(年龄)=1/2​×0+62​×0.5+1/6​×0≈0.167

(2)特征:是否有房产​

  • ​取值​​:有(3条)、无(3条)。

  • ​子集基尼系数​​:

    • 有:3条(1“否”,2“是”)→ Gini=1−(3/1​)^2−(3/2​)^2≈0.444

    • 无:3条(2“否”,1“是”)→ Gini≈0.444

  • ​加权基尼系数​​:

    Ginisplit​(房产)=1/2​×0.444+1/2​×0.444≈0.444

4、决策树剪枝(Pruning)

1. 剪枝的目的​

决策树容易过拟合(Overfitting),当你的数据量过大时,会导致树深度过大或节点过多。剪枝通过​​移除部分分支或子树​​,简化模型结构,提升泛化能力(指模型在​​未见过的数据​​上表现良好的能力,即从训练数据中学到的规律能否推广到新样本。)。

  • ​核心目标​​:在训练集准确性和测试集泛化性之间取得平衡。

​2. 剪枝方法分类​

​(1)预剪枝(Pre-Pruning)​

在决策树构建过程中提前停止生长,通过设定阈值限制树的复杂度。

预剪枝就像​​给树苗修剪枝叶​​,在决策树生长过程中提前阻止不必要的分支。通过设定规则限制树的复杂度,防止它长得"太茂盛"(过拟合)。

预剪枝的常见方法​

​(1) 限制树的高度(max_depth)​
  • ​作用​​:控制树的最大层数,避免决策规则过于复杂。

  • ​例子​​:贷款审批时,最多只问3个问题(如年龄→收入→房产),再多就拒绝(防止过度追问隐私)。

​(2) 设置节点最小样本数(min_samples_split)​
  • ​作用​​:节点至少需要多少样本才允许继续分裂。

  • ​例子​​:医生诊断时,至少要有10个相似病例才新增检查项,否则按经验开药。

​(3) 信息增益阈值(min_impurity_decrease)​
  • ​作用​​:只有分裂能显著提升分类效果时才允许分裂。

  • ​例子​​:挑西瓜时,如果"听声音"和"看颜色"判断效果差不多,就只用其中一个特征。

三、python实战

1.sklearn.tree.DecisionTreeClassifier()参数

sklearn.tree.DecisionTreeClassifier(criterion="gini",    splitter="best",    max_depth=None,   min_samples_split=2,  min_samples_leaf=1,   min_weight_fraction_leaf=0.0,  max_features=None,  random_state=None,   max_leaf_nodes=None, min_impurity_decrease=0.0,   class_weight=None, ccp_alpha=0.0,)

DecisionTreeClassifier是 scikit-learn 提供的分类决策树模型,适用于离散类别预测(如垃圾邮件分类、疾病诊断)。

1.核心分裂参数​

参数

说明

推荐值

引用

criterion

分裂质量评估标准:
'gini'(默认):基尼系数,计算更快
'entropy':信息增益,对不纯度更敏感但可能过拟合

高维数据用 'gini',低维清晰数据两者差异小

splitter

分裂策略:
'best'(默认):全局最优分裂
'random':局部随机分裂,适合大数据加速训练

小数据用 'best',大数据用 'random'


​2. 剪枝与复杂度控制​

参数

说明

推荐值

引用

max_depth

树的最大深度。None表示不限,但易过拟合

通常设为 3-10,通过交叉验证选择

min_samples_split

节点继续分裂的最小样本数:
- 整数:绝对数量
- 浮点数:占总样本比例

样本量大时建议 ≥10 或 0.01-0.1

min_samples_leaf

叶节点最小样本数,防止噪声干扰

分类任务建议 ≥5

max_leaf_nodes

限制叶节点总数,优先于 max_depth

特征多时设为 10-100

min_impurity_decrease

分裂需达到的最小不纯度减少量

0-0.1,值越大树越简单


​3. 特征与随机性控制​

参数

说明

推荐值

引用

max_features

分裂时考虑的最大特征数:
'sqrt'(默认):特征数的平方根
'log2':log2(特征数)
- 整数/浮点数:固定数量/比例

高维数据用 'sqrt'或 'log2'

random_state

随机种子,保证结果可复现

固定值如 42


​4. 类别不平衡处理​

参数

说明

推荐值

引用

class_weight

类别权重:
None(默认):等权重
'balanced':自动按类别频率反比加权

类别不平衡时用 'balanced'


​5. 其他实用参数​

参数

说明

推荐值

引用

ccp_alpha

代价复杂度剪枝参数,后剪枝强度

通过交叉验证选择

presort

预排序数据以加速训练(已弃用)

不推荐使用

 2.实例——客户历史流失数据

对于600多条数据,一共20个属性,两种状态,为二分类问题

import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split as trtesp
#sklearn自带的拆分数据集的函数train_test_split
#导入需要用到的库data = pd.read_excel('电信客户流失数据.xlsx')x = data.iloc[:, :-1]
y = data.iloc[:, -1]x_train,x_test,y_train,y_test = \trtesp(x, y, test_size=0.2, random_state=100)#拆分数据集lr = DecisionTreeClassifier()
lr.fit(x_train, y_train)from sklearn.model_selection import cross_val_score#导入打印召回率的函数
#定义循环列表定义最大深度,分裂的最小样本,叶节点最小样本数,最大叶子节点总数范围
max_depth = [i for i in range(4,13)]
min_samples_split = [i for i in range(2,10)]
min_samples_leaf = [i for i in range(12,19)]
max_leaf_nodes = [i for i in range(9,16)]scores = 0
best = []#定义最佳参数存放列表#利用循环嵌套寻找最佳参数
for depth in max_depth:for min_samples in min_samples_split:for min_leaf in min_samples_leaf:for max_leaf in max_leaf_nodes:lr = DecisionTreeClassifier(criterion='gini',max_depth=depth,min_samples_split=min_samples,min_samples_leaf=min_leaf,max_leaf_nodes=max_leaf,random_state=42)#进行交叉验证,评估模型的泛化能力score = /cross_val_score(lr, x_train, y_train, cv=8,scoring='recall')scores_m = sum(score)/len(score)#计算分数if scores_m > scores:#统计最大分数scores= scores_mbest = [depth,min_samples,min_leaf,max_leaf]print('最佳惩罚因子为:',best[:])
#训练最佳参数模型
lr = DecisionTreeClassifier(criterion='gini',max_depth=best[0],min_samples_split=best[1],min_samples_leaf=best[2],max_leaf_nodes=best[3],random_state=42)
lr.fit(x_train, y_train)
y_spred = lr.predict(x_train)
y_pred = lr.predict(x_test)
from sklearn import metrics
print('自测:',metrics.classification_report(y_train, y_spred))#获取自测报告
print('测试:',metrics.classification_report(y_test, y_pred))#获得测试集测试报告#输出决策树图像
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree
fig,ax = plt.subplots(figsize=(10,10))
plot_tree(lr,filled=True,ax=ax)
plt.show()

http://www.xdnf.cn/news/1242703.html

相关文章:

  • 并发编程常用工具类(下):CyclicBarrier 与 Phaser 的协同应用
  • C++入门自学Day6-- C++模版
  • 飞算JavaAI需求转SpringBoot项目沉浸式体验
  • 【BUUCTF系列】[极客大挑战 2019]LoveSQL 1
  • vllm启动Qwen/Qwen3-Coder-30B-A3B-Instruct并支持工具调用
  • MLIR Introduction
  • android内存作假通杀补丁(4GB作假8GB)
  • History 模式 vs Hash 模式:Vue Router 技术决策因素详解
  • ZYNQ-按键消抖
  • JavaScript 中的流程控制语句详解
  • 3.JVM,JRE和JDK的关系是什么
  • 第二十四天(数据结构:栈和队列)队列实践请看下一篇
  • SQL注入SQLi-LABS 靶场less39-50详细通关攻略
  • 基于实时音视频技术的远程控制传输SDK的功能设计
  • 【ECCV2024】AdaCLIP:基于混合可学习提示适配 CLIP 的零样本异常检测
  • [GESP202306 四级] 2023年6月GESP C++四级上机题超详细题解,附带讲解视频!
  • 刷题记录0804
  • ref和reactive的区别
  • 8位以及32位的MCU如何进行选择?
  • ArrayDeque双端队列--底层原理可视化
  • Redis 常用数据结构以及单线程模型
  • LeetCode 140:单词拆分 II
  • Array容器学习
  • app-1
  • 优选算法 力扣 11. 盛最多水的容器 双指针降低时间复杂度 贪心策略 C++题解 每日一题
  • Javascript面试题及详细答案150道之(031-045)
  • python包管理器uv踩坑
  • 力扣面试150题--加一
  • PCL统计点云Volume
  • ArcGIS的字段计算器生成随机数