当前位置: 首页 > news >正文

【机器学习-3】 | 决策树与鸢尾花分类实践篇

0 序言

本文将深入探讨决策树算法,先回顾下前边的知识,从其基本概念、构建过程讲起,带你理解信息熵、信息增益等核心要点。

接着在引入新知识点,介绍Scikit - learn 库中决策树的实现与应用,再通过一个具体项目的方式来帮助你掌握决策树在分类和回归任务中的使用,提升机器学习实践能力。

【项目在下文第四节,原理已经搞懂了前3节可跳过!!!】

1 决策树算法概述

1.1 决策树的基本概念

决策树是一种常见的机器学习算法,通过树状结构对数据进行分类回归模仿人类决策过程,利用一系列判断条件逐步划分数据。

其主要组成元素包括:

  • 根节点(Root Node):包含完整数据集的起始节点,是决策树的开端,所有数据最初都从这里开始进行划分。

  • 内部节点(Internal Node):表示一个特征或属性,用于对数据进行进一步的判断和划分。例如在预测水果类别时,内部节点可能是 “颜色” 这个特征。

  • 分支(Branch):表示决策规则,基于内部节点特征的不同取值而产生不同的分支。比如在颜色节点下,可能有红色、绿色等分支。

  • 叶节点(Leaf Node):表示决策结果,当数据经过一系列内部节点的判断和分支后,最终到达叶节点,得到相应的分类或回归值。如经过多个特征判断后,叶节点给出水果是苹果的结论。

也可参照下图作进一步理解。

在这里插入图片描述

1.2 决策树的优缺点

1.2.1 优点

  • 易于理解和解释:决策树以直观的树状结构呈现,其规则可清晰解读,方便人们理解模型的决策逻辑。例如,通过观察树结构,能直接明白根据哪些特征如何做出最终决策。

  • 较少的数据预处理需求:无需对数据进行复杂的归一化标准化处理,可直接处理数值型和类别型数据。像在处理客户购买行为数据时,年龄(数值型)和性别(类别型)可同时作为特征输入决策树。

  • 能处理多输出问题:可同时对多个目标变量进行预测,适用于复杂的多任务场景。如在预测天气状况时,可同时预测温度、湿度、天气类型等多个指标。

1.2.2 缺点

  • 容易过拟合:决策树可能过度学习训练数据中的细节和噪声,导致对新数据的泛化能力差。例如,训练数据中存在一些偶然因素,决策树可能将其当作普遍规律学习,在预测新数据时出错。

  • 对数据中的噪声敏感:数据中的噪声可能误导决策树的构建,影响模型准确性。比如错误标注的数据可能导致决策树生成不合理的分支。

  • 可能偏向大量级别的属性:在构建决策树时,算法可能倾向于选择取值较多的属性进行划分,而这些属性不一定对分类或回归最有价值。

2 决策树的构建过程

2.1 特征选择

特征选择是决策树构建的核心环节,目的是选择最佳划分特征,使划分后的数据集纯度更高。

常用的特征选择标准如下:

  • 信息增益(Information Gain):基于信息熵的减少量来选择特征。信息熵用于度量样本集合的纯度,信息熵越小,集合纯度越高。ID3 算法采用信息增益。例如,对于一个包含是否购买商品标签的数据集,计算年龄、性别等特征的信息增益,选择信息增益最大的特征作为当前节点的划分特征,这样能最大程度降低数据集的不确定性。

  • 信息增益比(Gain Ratio):对信息增益进行改进,考虑了特征本身的熵。它能减少信息增益对取值数目较多属性的偏好,C4.5算法使用信息增益比。例如,在某些数据集中,身份证号这种取值众多的特征信息增益可能很大,但对分类并无实际意义,信息增益比可避免这种情况。

  • 基尼指数(Gini Index)衡量数据集的不纯度,反映从数据集中随机抽取两个样本,其类别标记不一致的概率。CART 算法使用基尼指数。基尼指数越小,数据集纯度越高。

2.2 决策树的生成

决策树的生成是一个递归过程,主要步骤如下:

  1. 从根节点开始:此时根节点包含全部数据集。例如,有一个预测动物类别的数据集,初始时所有数据都在根节点。

  2. 计算所有可能的特征划分:针对根节点的数据集,计算每个特征不同取值下的划分情况。比如对于动物数据集,计算是否有翅膀、是否有毛发等特征不同取值时对数据集的划分效果。

  3. 选择最佳划分特征:根据信息增益、信息增益比或基尼指数等标准,选择能使划分后数据集纯度最高的特征。假设通过计算信息增益,发现是否有毛发这个特征划分后的数据集纯度提升最大,则选择它作为当前节点的划分特征。

  4. 根据该特征的取值创建子节点:根据所选特征的不同取值,将数据集划分到不同的子节点。如 是否有毛发特征有两个取值,数据就被分到两个子节点,每个子节点包含相应取值的数据子集。

  5. 对每个子节点递归地重复上述过程:对每个子节点的数据子集,再次进行特征划分、选择最佳特征、创建新子节点等操作,直到满足停止条件。停止条件通常包括节点中的样本全部属于同一类别、达到最大深度或节点中样本数量过少等。例如,某个子节点的数据集中所有动物都属于猫科动物类别,此时该子节点就成为叶节点,不再继续划分。

2.3 决策树的剪枝

为防止决策树过拟合,需要进行剪枝操作。剪枝分为预剪枝后剪枝

  • 预剪枝(Pre - pruning):在树构建过程中提前停止。例如,在构建决策树时,如果某个节点的划分不能带来决策树泛化能力的提升(如通过验证集评估),则停止划分,将该节点标记为叶节点。还可以通过设置一些参数来实现预剪枝,如限制叶子节点的样本个数,当样本个数小于一定阈值时,不再继续创建分支;或者设定信息熵减小的阈值,当信息熵减小量小于该阈值时,停止创建分支。

  • 后剪枝(Post - pruning)先构建完整树,然后剪去不重要的分支。从已经构建好的完全决策树的底层开始,对非叶节点进行考察,若将该节点对应的子树替换成叶节点可以带来决策树泛化性能的提升(如在验证集上错误率降低),则将该子树替换为叶节点。例如,有一棵完整的决策树,对某个非叶节点进行评估,发现将其下面的子树替换为一个叶节点后,在验证集上的准确率提高了,就进行此替换操作。

这里如果单独看觉得比较抽象,可以看以下这个图。

还是前面那张图,假如为剪枝前的,

在这里插入图片描述

那么剪枝后就可能是这样子。

在这里插入图片描述

就相当于说,在不损失算法的性能的前提下,将比较复杂的决策树,化简为较为简单的版本。

3 Scikit - learn中的决策树 API

Scikit - learn提供了两个主要的决策树实现:

  • DecisionTreeClassifier:用于分类问题,通过构建决策树对数据进行分类预测。例如,预测客户是否会购买某产品、邮件是否为垃圾邮件等。

  • DecisionTreeRegressor:用于回归问题,通过决策树模型对连续型目标变量进行预测。比如预测房价、股票价格等。

3.1 DecisionTreeClassifier 参数详解

class sklearn.tree.DecisionTreeClassifier(criterion='gini', splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0, class_weight=None, ccp_alpha=0.0
)
  • criterion:衡量分割质量的函数,可选值有:

  • “gini”:基于基尼不纯度,计算概率分布的基尼系数,反映数据集的不纯度。

  • “entropy”:基于信息增益,通过计算信息熵的减少量来衡量分割质量。

  • “log_loss”:从 v1.3 版本开始新增,基于对数损失。默认值为gini。例如,如果数据集中类别分布较为均匀,使用entropy可能更合适,能更好地衡量信息增益。

  • splitter:选择每个节点分割的策略,可选值有:

  • “best”:选择最佳分割,通过计算所有可能的分割情况,找到使分割后数据集纯度最高的方案。

  • “random”:随机选择分割,这种方式可能更快,但结果不一定是最优的,可用于减少计算量或防止过拟合。默认值为best

  • max_depth:树的最大深度,默认值为 None。当为 None 时,树会一直扩展,直到所有叶子都是纯的(即叶子节点中的样本都属于同一类别)或包含少于 min_samples_split 个样本。如果设置为整数,如 5,则限制树的最大深度为 5 层,可有效防止过拟合。

  • min_samples_split:分割内部节点所需的最小样本数,默认值为 2。可以是整数,表示绝对数量;也可以是浮点数,表示占总样本的比例。例如,设置为 5,表示内部节点至少需要 5 个样本才会进行分割;设置为 0.1,表示内部节点样本数占总样本数的比例至少为 10% 时才进行分割。

  • min_samples_leaf:叶节点所需的最小样本数,默认值为 1。同样可以是整数或浮点数。例如,设置为 3,可防止创建样本数过少的叶节点,避免模型过拟合。

  • min_weight_fraction_leaf叶节点所需的权重总和的最小加权分数,默认值为 0.0。当提供了 sample_weight(样本权重)时使用,与 min_samples_leaf 二选一。例如,设置为 0.1,表示叶节点权重和必须大于等于总权重的 10%。

  • max_features:寻找最佳分割时要考虑的特征数量,默认值为 None。可选值有:

  • None / 不设置:考虑所有特征。

  • “auto”/“sqrt”:表示考虑 <inline_LaTeX_Formula>\sqrt {n_features}<\inline_LaTeX_Formula > 个特征,其中 < inline_LaTeX_Formula>n_features<\inline_LaTeX_Formula > 是特征总数。

  • “log2”:表示考虑 <inline_LaTeX_Formula>log2 (n_features)<\inline_LaTeX_Formula > 个特征。

  • 整数:表示绝对数量的特征。

  • 浮点数:表示占总特征的比例。该参数可影响训练速度和模型随机性,例如设置为 “sqrt”,可减少计算量,提高训练速度。

  • random_state:控制随机性,默认值为 None。当为 None 时,使用随机数生成器的默认状态;当为整数时,作为随机数生成器的种子,确保结果可重现。例如,设置 random_state = 42,多次运行代码可得到相同的随机结果,方便调试和对比实验。

  • max_leaf_nodes:最大叶节点数量,默认值为 None。当为 None 时,不限制叶节点数量;当为整数时,以最佳优先方式生长树,可作为替代 max_depth 的剪枝方法。例如,设置为 10,可限制决策树最多有 10 个叶节点,防止树生长过大导致过拟合。

  • min_impurity_decrease:分割节点需要的最小不纯度减少量,默认值为 0.0。计算公式为 <inline_LaTeX_Formula>\frac {N_t}{N} \times (impurity - \frac {N_t_R}{N_t} \times right_impurity - \frac {N_t_L}{N_t} \times left_impurity)<\inline_LaTeX_Formula>,其中 < inline_LaTeX_Formula>N<\inline_LaTeX_Formula > 是总样本数,<inline_LaTeX_Formula>N_t<\inline_LaTeX_Formula > 是当前节点样本数。当某个节点的不纯度减少量大于等于该值时,才会进行分割。

  • class_weight:类别权重,默认值为 None。可选值有:

  • None:所有类别权重为 1,即不考虑类别不平衡问题。

  • “balanced”:自动计算权重,与类别频率成反比,用于处理类别不平衡问题。例如,某个类别样本数量很少,其权重会相应增大。

  • 字典:手动指定 {class_label: weight},可根据实际情况为每个类别设置权重。

  • ccp_alpha最小成本复杂度剪枝参数,默认值为 0.0。当为 0.0 时,默认不剪枝;当大于 0.0 时,较大的值会导致更多剪枝,可通过交叉验证选择最优值。例如,通过设置不同的 ccp_alpha 值,在验证集上评估模型性能,选择使模型泛化能力最佳的 ccp_alpha 值。

3.2 DecisionTreeRegressor 参数详解

DecisionTreeRegressorDecisionTreeClassifier的参数大部分相同,但criterion参数有所不同:

class sklearn.tree.DecisionTreeRegressor(criterion='squared_error', splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0, ccp_alpha=0.0
)
  • criterion:衡量分割质量的标准,默认值为squared_error(均方误差),可选值还有:

  • “absolute_error”:平均绝对误差。

  • “friedman_mse”:弗里德曼均方误差,考虑了数据集的特征数量和样本数量对误差的影响。

  • “poisson”泊松偏差,适用于目标变量服从泊松分布的数据。不同的 criterion 适用于不同的数据分布和问题场景,例如,当数据中存在较多异常值时,absolute_error可能比squared_error更稳健。

4 决策树实战示例

4.1 前期准备

导入需要的4个库。

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import classification_report, accuracy_score

pandas库主要用于数据的结构化处理,比如数据清洗就要用到它;

numpy库则是用于数值计算和数组操作。

matplotlib的pyplot模块是用于绘制基础图表;

seaborn库是为了绘制更美观的统计图表。

下面这四条程序与机器学习有关。

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import classification_report, accuracy_score
  • from sklearn.datasets import load_iris:从scikit-learn库导入鸢尾花数据集

  • from sklearn.model_selection import train_test_split:用于将数据集分割为用于模型训练的训练集用于模型评估的测试集

  • from sklearn.tree import DecisionTreeClassifier, plot_tree

    • DecisionTreeClassifier:决策树分类器类,用于构建决策树模型
    • plot_tree:可视化决策树结构的函数
  • from sklearn.metrics import classification_report, accuracy_score

    • accuracy_score:计算模型预测的准确率
    • classification_report:生成详细的分类评估报告,包含精确率、召回率、F1分数等指标

这些库组合起来,就是前期需要完成的准备工作。

4.2 加载并探索数据集

我们在拿到一个全新的数据集之前,由于一开始对数据集并不熟悉,

它到底有多少标签,如何划分,数据结构如何,

这些都是未知的,因此本小节的内容就很重要,

因为它是让我们熟悉数据集,

只有充分熟悉数据集,你才能更好地去使用它作为数据支撑去训练你的模型。

# Load iris dataset
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['species'] = iris.target# Map numeric targets to class labels
df['species'] = df['species'].map({i: name for i, name in enumerate(iris.target_names)})# Show dataset info
print("Dataset shape:", df.shape)
df.head()

下面针对以上程序进一步分析讲解。

1. 加载鸢尾花数据集

iris = load_iris()
  • load_iris() 是 sklearn 提供的加载鸢尾花数据集的函数
  • 返回的 iris 是一个类似字典的对象,包含数据集的各种信息(特征数据、标签、特征名称等)

2. 构建DataFrame

df = pd.DataFrame(iris.data, columns=iris.feature_names)
  • 将鸢尾花的特征数据(iris.data)转换为 pandas 的 DataFrame 格式,方便后续处理
  • columns=iris.feature_names 为 DataFrame 设置列名,这些列名是鸢尾花的4个特征名称:
    • sepal length (cm)(花萼长度)
    • sepal width (cm)(花萼宽度)
    • petal length (cm)(花瓣长度)
    • petal width (cm)(花瓣宽度)

我在网上找了一张图,可以参考这张图标的参数具体在现实中花朵属于哪个部位。

在这里插入图片描述
另外,

这里转成DataFrame格式是为了给每列数据都加上具体的标签,以便于后续数据处理,

原先数据集是个二维数组,只有纯数字并没有标签。

3. 添加标签列并映射为类别名称

df['species'] = iris.target
df['species'] = df['species'].map({i: name for i, name in enumerate(iris.target_names)})
  • 第一行:添加标签列 speciesiris.target 是原始的数值标签(0、1、2)
  • 第二行:将数值标签映射为实际的鸢尾花品种名称:
    • 0 → setosa(山鸢尾)
    • 1 → versicolor(变色鸢尾)
    • 2 → virginica(维吉尼亚鸢尾)
  • 映射使用字典推导式,将 iris.target_names 中的品种名称与索引对应起来。

4. 展示数据集信息

print("Dataset shape:", df.shape)
df.head()
  • print("Dataset shape:", df.shape) 输出数据集的形状(样本数×特征数)
  • df.head() 显示数据集的前5行数据。

一般的图表,第一行基本都会写上标签,

我们快速了解表的结构和内容一般看头五行基本就大概了解了,

我这里使用的数据集,鸢尾花数据集固定为 (150, 5)(150个样本,4个特征+1个标签)

运行以上程序,初步认识下数据集。

在这里插入图片描述
这里就清楚了,

通过这个结构化的鸢尾花数据集表格,

就清楚它包含4个特征列1个品种标签列

那我们后续的数据分析和模型训练就需要以它为标准。

4.3 可视化鸢尾花数据集特征分布

接着我们用图表的形式来具体看看该数据集的特征分布,

同理,本小节的方法你用在其他数据集上也是同样的道理,

可以说是必不可少的一步。

sns.pairplot(df, hue='species')
plt.suptitle("Feature Distribution by Species", y=1.02)
plt.show()

先对以上程序进行简单的分析:

这里主要是借助 seabornpairplot 函数,从多个维度展示特征间关系与分布规律


1. 核心:sns.pairplot(df, hue='species')

  • sns.pairplot
    seaborn 库的配对图函数,自动遍历数据集所有数值特征,绘制两两特征的关系图(含散点图、直方图),你可以借助这些图来快速探索特征间相关性、分布规律
  • df
    传入预处理好的 DataFrame(含鸢尾花特征+标签),pairplot 会自动解析其列数据。
  • hue='species'
    species 列(鸢尾花品种:setosa、versicolor、virginica)对数据上色区分,让不同品种在图中呈现不同颜色,可以更加直观地对比分布差异

2. 次要优化:plt.suptitle("Feature Distribution by Species", y=1.02)

  • plt.suptitle
    给整个图像添加标题。
  • y=1.02
    调整标题位置。

3. 渲染输出:plt.show()

  • 触发 matplotlib 渲染,将 pairplot 生成的所有子图展示在窗口/Notebook 中。

看看运行的结果:

在这里插入图片描述

得到图片后,我们要对图片进行分析,

这里很重要,因为图片分析好坏决定我们后续使用哪类特征去进行节点分类!!!


下面先对图中画红框的图片进行分析,

也就是对角线的那四个图表,

准确来说,它应该是直方图。

在这里插入图片描述

我们通过以下表格来认识区分它。

这里算是单特征分布的示意图,展示单个特征在不同品种中的数值分布。

特征子图分布规律 & 品种差异
sepal length setosa花萼长度明显更小,集中在4.5-5.5cm;versicolor和 virginica重叠较多,范围5-7cm
sepal width setosa花萼宽度更大(3-4.5cm),另外两个品种集中在2.5-3.5cm,重叠严重
petal lengthsetosa花瓣长度极小,versicolor集中在3-5cm,virginica集中在4.5-7cm,区分度极高
petal widthsetosa花瓣宽度极小,versicolor集中在1-1.8cm,virginica集中在1.8-2.5cm,区分度很高

这四个图分析完后,

还剩下12个图,剩下的图就是双特征关联的散点图

这些图主要是展示特征间的关联关系,以及不同品种的聚类规律。

这里我们从中挑选2张图进行分析就好。

入下图中用红框标出来的图,

在这里插入图片描述

从这张图中,我们可以注意到:

  • setosa花瓣短、花萼也短,形成左下角的密集点;
  • versicolor花瓣长度中等,花萼长度中等;
  • virginica花瓣长、花萼也长,分布在右上角。

这里初步得出:花萼长度和花瓣长度正相关,且不同品种沿对角线分布,有一定区分度

再来,我们再看下面这张图:

在这里插入图片描述

同理,可以观察到:

  • setosa花瓣宽度窄、花萼宽度宽,形成左上角的点;
  • versicolor和 virginica花瓣宽度越大,花萼宽度也越大,但两者重叠较多。

得出初步结论,发现花萼宽度和花瓣宽度弱正相关,setosa 与其他品种区分明显,但versicolor和 virginica 难通过这对特征区分。

接着再通过比对右下角两个图以及左侧三个正相关的图,

最终得出如下总结:

  1. 最具区分度的特征petal lengthpetal width
  2. 次优特征sepal length(能辅助区分 versicolor 和 virginica)。
  3. 最无效特征sepal width(因为versicolor 和 virginica 重叠严重,难以单独区分)。

这里尤其是 setosa 与其他品种的边界非常清晰,

所以后续用决策树/分类模型时,petal lengthpetal width优先成为分裂节点!!!

因为它们能最大程度降低分类不确定性;

sepal width 可能在深层节点才被使用,或对模型贡献较小。

本小节虽然程序不多,理解难度也比较低,

但是核心难点在于我们如何通过得到的图片进行判断,

通过观察分布和聚类规律,预判哪些特征对分类最有效,避免盲目建模,让后续模型训练更高效!

4.4 准备用于模型的数据

X = df.drop('species', axis=1)
y = df['species']# Train-test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)print("Training samples:", X_train.shape[0])
print("Testing samples:", X_test.shape[0])

老样子,先对程序进行分析。

核心在于做数据准备,分为分离特征与标签划分训练集和测试集两个关键步骤

1.分离特征(X)和标签(y)

# Features and labels
X = df.drop('species', axis=1)
y = df['species']
  • X = df.drop('species', axis=1)

    • df 是包含鸢尾花数据的 DataFrame(有4个特征列 + 1个 species 标签列)。
    • drop('species', axis=1) 表示删除 species 这一列axis=1 代表操作列),剩下的列作为特征数据(X),用于训练模型输入
    • 最终 X 包含4列:sepal lengthsepal widthpetal lengthpetal width
  • y = df['species']

    • 提取 species 这一列作为标签(y),作为模型需要预测的输出

2.划分训练集和测试集

# Train-test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
  • train_test_split
    这是 sklearn 提供的数据集拆分工具,作用是把特征(X)和标签(y)随机拆成两部分
    • 训练集(X_train, y_train)喂给模型学习规律;

    • 测试集(X_test, y_test):验证模型学得好不好;

    • test_size=0.3:选择常见的37开,测试集占总数据的 30%,训练集则占 70%;

    • random_state=42:固定随机种子,让每次拆分结果完全一致

3.输出数据集大小

print("Training samples:", X_train.shape[0])
print("Testing samples:", X_test.shape[0])
  • X_train.shape[0] 取训练集的样本数量;
  • X_test.shape[0] 取测试集的样本数量;

验证拆分比例是否正确,确保数据准备无误。

本小节的内容也是比较重要的,

算得上是必经步骤,

后续再为模型准备数据的时候牢记三步走:

  1. 分离特征X和标签y,让模型明确输入输出
  2. 拆分训练集/测试集,用训练集学规律、测试集验效果的方式,避免模型过拟合
  3. 固定 random_state 让实验可复现,方便调试和对比不同模型的效果。

这一步骤做好后,后续建模就可以直接用 X_train, y_train 训练模型,再用 X_test, y_test 评估效果,整个流程就打通啦。

在这里插入图片描述
150×30%=45(测试集样本数)150×70%=105(训练集样本数) 150 \times 30\% = 45 \quad \text{(测试集样本数)} \\ 150 \times 70\% = 105 \quad \text{(训练集样本数)} 150×30%=45(测试集样本数)150×70%=105(训练集样本数)

输出的结果跟预期计算的结果是一致的,

这个流程就顺利完成了。

4.5 训练决策树模型

clf = DecisionTreeClassifier(criterion='entropy', max_depth=3, random_state=42)
clf.fit(X_train, y_train)

对程序进行分析,

clf = DecisionTreeClassifier(criterion='entropy', max_depth=3, random_state=42)
  • DecisionTreeClassifier
    sklearn 决策树分类器的,初始化时通过参数配置模型的结构和训练规则

  • 参数解释见下表

    参数名作用你的配置值
    criterion决策树分裂依据,可选 gini(默认)或 entropy(信息熵)entropy
    max_depth限制树的最大深度,避免过拟合(值越小模型越简单)3
    random_state固定随机种子,让每次训练结果一致(可复现)42
clf.fit(X_train, y_train)
  • fit 是模型的训练方法,传入训练集特征(X_train训练集标签(y_train,让决策树学习特征→标签的映射规律。
  • 训练过程:决策树根据 criterion='entropy' 计算信息熵,不断分裂节点,直到树深达到 max_depth=3,最终形成一棵有决策逻辑的树。

这里对函数内部参数选择的具体数值进行解析:

  • criterion='entropy':让决策树优先选择信息增益大的特征分裂;
  • max_depth=3:避免树过深,让模型更通用;
  • random_state=42:能得到完全一样的决策树,否则每次训练结果不同。

在这里插入图片描述
运行后这里就会显示你设置的各类参数。

4.6 可视化决策树

plt.figure(figsize=(15,10))
plot_tree(clf, filled=True, feature_names=X.columns, class_names=clf.classes_)
plt.title("🌳 Decision Tree Visualization", fontsize=16)
plt.show()

这段代码将抽象的决策树规则转化为直观的图形,方便理解模型的分类逻辑。

1. plt.figure(figsize=(15,10))
创建一个绘图窗口(画布),指定画布宽度为15英寸,高度为10英寸

这里根据具体要求来,数值越大,图越清晰,如果要展示复杂的决策树,可以在这个基础上加点。

2. plot_tree(...)(核心可视化函数)

plot_tree(clf, filled=True, feature_names=X.columns, class_names=clf.classes_)
  • plot_tree 直接将训练好的决策树(clf)转化为图形。
  • 关键参数解析:
    • clf:传入已训练好的决策树模型;
    • filled=True:为树的节点填充颜色,颜色深浅表示节点的纯度(纯度越高,颜色越深,直观区分不同类别占比)。
    • feature_names=X.columns:指定特征名称,替代默认的 X[0]、X[1] 等抽象编号,让节点的判断条件更易读。
    • class_names=clf.classes_:指定类别名称(如 setosa、versicolor、virginica),替代默认的数字标签,让叶节点的分类结果更直观。

3. plt.title("🌳 Decision Tree Visualization", fontsize=16)
为整个决策树图添加标题并设置标题字体大小为16。

4. plt.show()
用matplotlib渲染并显示图像,将绘制好的决策树展示在屏幕上。

运行程序看看具体效果:

在这里插入图片描述

我们来对这张图片进行分析:

首先来分析根节点,见下图:
在这里插入图片描述

接下来看一下根节点的分支

在这里插入图片描述
然后看一下第二层的分支

在这里插入图片描述

接着看一下第三层节点

在这里插入图片描述
最后看一下叶节点

在这里插入图片描述

好啦,到这里图片就都分析完了,

我们可以从中获取到一些规律并进行总结。

分为3点来阐述:

1. 最关键的特征: petal length(花瓣长度)

根节点直接用它拆分,把 setosa 一键筛出;

第二层继续用它拆分 versicolor 和 virginica。

2. 最明显的边界 → setosa 的花瓣长度

setosa 几乎全部分布在花瓣长度 ≤2.45cm,和另外两类完全分开。

3. 最易混淆的类别 → versicolor 和 virginica

前两层筛不掉,需要靠 花瓣宽度(versicolor 窄)和 更长的花瓣长度(virginica 长)进一步区分。

总结
在特征选择方面,模型自动选 petal length 作为根节点,说明花瓣长度是区分鸢尾花最核心的特征,和我们之前可视化 pairplot 的结论一致!

但与此同时,还是发现过拟合风险,比如叶节点里,浅紫色节点熵 = 0.918,还有 5 个 versicolor 没分干净,这说明当树深 = 3 时,仍有少量混杂,但整体不影响大分类。如果继续加深树(如 max_depth=4),可能会导致过拟合!

本小节用了大量篇幅来进行讲解,就是为了说明一点,那就是通过可视化,我们能直观看懂模型的决策逻辑,比如模型优先用哪个特征划分数据,不同特征的判断阈值是多少,最终如何得出分类结果。这对于理解模型、解释预测结果,如“为什么这个样本被分为 virginica”以及调参优化,如判断树深是否合适都非常重要!!!

4.7 评估模型

y_pred = clf.predict(X_test)print("🔍 Classification Report:\n")
print(classification_report(y_test, y_pred))print("✅ Accuracy Score:", accuracy_score(y_test, y_pred))

可以使用以上程序去评估一下这个模型,

通过预测结果与真实标签的对比,从而量化模型的分类效果

下面先分析下这段程序:

1. y_pred = clf.predict(X_test)

用训练好的决策树模型(clf)对测试集特征(X_test)进行预测,得到模型的预测标(y_pred

在这个过程中,模型会根据之前学到的决策规则,对测试集中的45个样本逐一判断,输出每个样本的预测品种。

2. 输出分类报告:classification_report(y_test, y_pred)

print("🔍 Classification Report:\n")
print(classification_report(y_test, y_pred))

classification_report 是 sklearn 提供的综合评估工具,对比真实标签(y_test)和预测标签(y_pred),输出每个类别的详细指标,具体可以见下表:

指标含义
precision精确率:预测为某品种的样本中,真正属于该品种的比例
recall召回率:该品种的真实样本中,被模型正确预测出来的比例
f1-score精确率和召回率的调和平均(综合两者,值越接近1越好)
support该品种在测试集中的真实样本数量(如测试集中有14个setosa)

最后, macro avg为宏平均,所有类别平等加权;

weighted avg为加权平均,按样本数加权,评估模型整体表现。

3. 输出准确率:accuracy_score(y_test, y_pred)

print("✅ Accuracy Score:", accuracy_score(y_test, y_pred))

准确率(Accuracy):所有测试样本中,预测正确的比例(正确预测数 / 总样本数)。

我们运行一下程序,看看结果:

在这里插入图片描述

对图片进行分析,

这里可以用一张表来概括说明更好。

1.逐类别分析(setosa、versicolor、virginica)

类别precision(精确率)recall(召回率)f1-score(综合分)support(真实样本数)解读
setosa1.001.001.0019模型完美识别!预测为setosa的样本,100%是真setosa;所有真实setosa也全被找出来了
versicolor1.000.920.9613精确率满分(预测versicolor的样本都真),但召回率稍低(13个真实样本中,1个被误判)
virginica0.931.000.9613召回率满分(真实virginica全被找),但精确率稍低(预测virginica的样本中,7%是误判)

2.整体指标(accuracy、macro avg、weighted avg)

指标数值解读
accuracy0.98整体准确率:45个测试样本中,98%预测正确,仅1-2个样本分类错误
macro avg0.98/0.97/0.97宏平均:综合三类的精确率、召回率、F1,整体表现极佳
weighted avg0.98/0.98/0.98加权平均:因setosa表现极佳,拉高了整体分数,模型稳健

总体的表现性能很优秀,但有些地方还有一定的优化空间,下面给出一些优化方向。

  1. 调整决策树深度
    当前 max_depth=3,若增至4,模型可能学习更细的边界,区分开那2个误判样本。

  2. 验证测试集分布
    测试集中 versicolor 和 virginica 这两类样本较少,若有更多样本,模型表现会更稳定。

本小节,通过验证结果可知模型整体表现优秀,setosa 完美识别,仅在 versicolor 和 virginica 的边界样本上有微小失误。对于鸢尾花这种简单数据集,当前决策树的配置已经足够了!

5 小结

本文是前文【机器学习-2】的进阶,在前文中主要学习了决策树的原理,在这里就用一个具体的项目来进行学习和更深入的理解,本文算是实践篇。

在本次基于鸢尾花数据集的决策树实验中,模型表现优异。借助 scikit-learn 仅需数行代码,即可完成从数据加载、训练到可视化的全流程;通过分类报告可知,决策树在测试集上准确率达 98%,且可视化决策树清晰展现了花瓣长度、宽度等特征阈值如何驱动分类,模型逻辑直观。

若追求更优效果,可从三方面迭代:一是调参探索(如调整 max_depth 平衡复杂度、修改 min_samples_split 控制节点分裂、切换 criterion 对比 entropy 与 gini 差异);二是剪枝优化(通过预剪枝 / 后剪枝限制过拟合,如设置 ccp_alpha);三是跨场景验证,尝试泰坦尼克生存预测、葡萄酒质量分类等真实数据集,检验决策树在复杂业务中的泛化能力,进一步挖掘模型潜力。

http://www.xdnf.cn/news/1200043.html

相关文章:

  • 【Typora】分享一款很好用的PJ版本的Markdown编辑器
  • k8s pod生命周期、初始化容器、钩子函数、容器探测、重启策略
  • S7-1500 与 S7-1200 存储区域保持性设置特点详解
  • ESP32学习-FreeRTOS队列使用指南与实战
  • 回归预测 | MATLAB实现BiTCN双向时间卷积神经网络多输入单输出回归预测
  • 如何在 Ubuntu 24.04 或 22.04 中更改 SSH 端口
  • 个人笔记HTML5
  • 【ee类保研面试】通信类---信息论
  • [2025CVPR-图象超分辨方向]DORNet:面向退化的正则化网络,用于盲深度超分辨率
  • 标签驱动的可信金融大模型训练全流程-Agentar-Fin-R1工程思路浅尝
  • Unity Catalog与Apache Iceberg如何重塑Data+AI时代的企业数据架构
  • JavaEE初阶第十二期:解锁多线程,从 “单车道” 到 “高速公路” 的编程升级(十)
  • LeetCode 239:滑动窗口最大值
  • 模拟实现python的sklearn库中的Bunch类以及 load_iris 功能
  • RocksDB 高效采样算法:水塘抽样和随机寻址
  • WAIC 2025 热点解读:如何构建 AI 时代的“视频神经中枢”?
  • [N1盒子] 斐讯盒子N1 T1通用刷机包(可救砖)
  • SpringBoot 整合 Langchain4j AIService 深度使用详解
  • Valgrind Helgrind 工具全解:线程同步的守门人
  • 编程语言Java——核心技术篇(五)IO流:数据洪流中的航道设计
  • JavaWeb(苍穹外卖)--学习笔记13(微信小程序开发,缓存菜品,Spring Cache)
  • Java中get()与set()方法深度解析:从封装原理到实战应用
  • 8. 状态模式
  • 零基础 “入坑” Java--- 十五、字符串String
  • 一场关于电商零售增长破局的深圳探索
  • 金融科技中的跨境支付、Open API、数字产品服务开发、变革管理
  • 【Ollama】大模型本地部署与 Java 项目调用指南
  • 字符串是数据结构还是数据类型?
  • 基于Prometheus+Grafana的分布式爬虫监控体系:构建企业级可观测性平台
  • Git Commit 生成与合入 Patch 指南