当前位置: 首页 > news >正文

【数学建模学习笔记】机器学习分类:KNN分类

KNN 分类算法详解(Python 实现)

KNN(K 近邻)算法是机器学习中最简单的监督学习分类算法之一。本文将为初学者详细解释 KNN 的原理和实现过程,并通过实际代码示例帮助理解。

什么是 KNN 算法?

KNN 的核心思想非常简单:物以类聚,人以群分。当我们要对一个新样本进行分类时,只需看看它周围最近的 K 个样本属于什么类别,然后让新样本 "跟随多数派"。

想象你在一个派对上,想知道自己应该坐在哪个区域。你看看离你最近的 3 个人(K=3)坐在哪里,如果 2 个在 A 区,1 个在 B 区,你很可能也会选择 A 区。这就是 KNN 的基本思想。

实现步骤

让我们一步步实现 KNN 分类算法,从数据准备到最终评估。

1. 导入需要的库

首先,我们需要导入一些 Python 库来帮助我们处理数据和实现算法:

import numpy as np  # 用于数值计算
import pandas as pd  # 用于数据处理
from sklearn.model_selection import train_test_split  # 用于划分训练集和测试集
from sklearn.preprocessing import StandardScaler  # 用于数据标准化
from sklearn.neighbors import KNeighborsClassifier  # KNN分类器
from sklearn.metrics import accuracy_score, classification_report  # 评估指标
import matplotlib.pyplot as plt  # 用于可视化
from mpl_toolkits.mplot3d import Axes3D  # 用于3D可视化

2. 加载并查看数据

我们使用一个包含收入、年龄、学历和对应类别的数据集:

# 加载数据
df = pd.read_excel('https://labfile.oss.aliyuncs.com/courses/40611/K%E8%BF%91%E9%82%BB%28KNN%29%E5%88%86%E7%B1%BB.xlsx')# 查看前5行数据
df.head()

数据看起来是这样的:

收入年龄学历类别
495622专科普通人
1039721本科精英
515830专科普通人
436822专科普通人
1268035专科精英

3. 数据预处理

计算机不理解文字,所以我们需要将中文类别转换为数字:

# 将学历文本转换为数字
df['学历'] = df['学历'].map({"专科": 1,"本科": 2,"硕士": 3,"博士": 4
})# 将类别文本转换为数字
df['类别'] = df['类别'].map({"普通人": 0,"精英": 1,"高质量": 2
})# 查看转换后的数据
df.head()

为了符合国际惯例,我们将中文列名改为英文:

column_mapping = {'收入': 'Income','年龄': 'Age','学历': 'Education_Level','类别': 'Category'
}
df.rename(columns=column_mapping, inplace=True)

4. 准备特征和标签

在机器学习中,我们通常将数据分为 "特征"(输入)和 "标签"(输出):

# 特征:收入、年龄、学历(用于预测的属性)
X = df[['Income', 'Age', 'Education_Level']]# 标签:类别(我们要预测的结果)
y = df['Category']

5. 数据标准化

KNN 算法基于距离计算,不同特征的单位可能不同(如收入是几千,年龄是几十),这会影响距离计算。因此我们需要标准化数据:

# 创建标准化器
scaler = StandardScaler()# 对特征进行标准化处理
X_scaled = scaler.fit_transform(X)

标准化后的数据均值为 0,标准差为 1,确保每个特征对距离计算的影响是均衡的。

6. 划分训练集和测试集

我们需要一部分数据来训练模型,另一部分来测试模型的效果:

# 划分数据集:80%用于训练,20%用于测试
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42
)
  • X_trainy_train:训练数据和对应的标签
  • X_testy_test:测试数据和对应的标签
  • test_size=0.2:测试集占 20%
  • random_state=42:固定随机种子,确保结果可重现

7. 实现 KNN 分类

现在我们可以创建并训练 KNN 模型了:

# 创建KNN分类器,K=3(选择最近的3个邻居)
knn = KNeighborsClassifier(n_neighbors=3)# 用训练数据训练模型
knn.fit(X_train, y_train)# 用训练好的模型预测测试集
y_pred = knn.predict(X_test)

8. 评估模型效果

我们需要看看模型预测得准不准:

# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
# 生成详细的分类报告
report = classification_report(y_test, y_pred)print(f'准确率: {accuracy}')
print('分类报告:\n', report)

输出结果:

准确率: 0.85
分类报告:precision    recall  f1-score   support0       1.00      1.00      1.00        241       0.93      0.62      0.74        212       0.64      0.93      0.76        15accuracy                           0.85        60macro avg       0.85      0.85      0.83        60
weighted avg       0.88      0.85      0.85        60
  • 准确率 (accuracy):85%,表示总体预测正确的比例
  • 精确率 (precision):预测为某类且实际为此类的比例
  • 召回率 (recall):实际为某类且被正确预测的比例
  • F1 分数:精确率和召回率的调和平均

9. 可视化分类结果

我们可以用 3D 图直观地展示分类结果:

# 创建3D图形
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')# 绘制测试数据点,用颜色表示预测的类别
scatter = ax.scatter(X_test[:, 0],  # 收入X_test[:, 1],  # 年龄X_test[:, 2],  # 学历c=y_pred,      # 用预测类别着色cmap='viridis' # 颜色映射
)# 设置坐标轴标签
ax.set_xlabel('Income (标准化)')
ax.set_ylabel('Age (标准化)')
ax.set_zlabel('Education Level (标准化)')# 调整视角
ax.view_init(elev=30, azim=70)# 添加图例
legend1 = ax.legend(*scatter.legend_elements(), title="Category")
ax.add_artist(legend1)plt.show()

K 值的选择

K 值是 KNN 算法中最重要的参数:

  • K 值太小:模型容易过拟合,对噪声敏感
  • K 值太大:模型过于简单,可能忽略数据的真实模式

通常可以通过尝试不同的 K 值(如 1, 3, 5, 7...),选择性能最好的那个。

# 尝试不同的K值
for k in [1, 3, 5, 7, 9, 11]:knn = KNeighborsClassifier(n_neighbors=k)knn.fit(X_train, y_train)y_pred = knn.predict(X_test)accuracy = accuracy_score(y_test, y_pred)print(f'K={k} 时的准确率: {accuracy:.2f}')

距离计算

KNN 中常用的距离计算方法:

  1. 欧氏距离:最常用,计算两点之间的直线距离
  2. 曼哈顿距离:计算两点在坐标轴上的绝对距离之和
  3. 余弦相似度:衡量两个向量方向的相似性

在 scikit-learn 中,KNeighborsClassifiermetric参数可以指定距离度量方法,默认为欧氏距离。

KNN 的优缺点

优点

  • 简单易懂,容易实现
  • 不需要训练过程,是一种 "惰性学习" 算法
  • 可以处理多分类问题
  • 对异常值不敏感(当 K 较大时)

缺点

  • 计算量大,预测速度慢(需要计算与所有样本的距离)
  • 对高维数据效果不好("维度灾难")
  • 对不平衡数据敏感
  • 需要存储所有训练样本,内存占用大
http://www.xdnf.cn/news/1458379.html

相关文章:

  • Full cycle of a machine learning project|机器学习项目的完整周期
  • 9.4C++——继承
  • MySQL命令--备份和恢复数据库的Shell脚本
  • C++工程实战入门笔记11-三种初始化成员变量的方式
  • TCP协议的三次握手与四次挥手深度解析
  • 从头开始学习AI:第二篇 - 线性回归的数学原理与实现
  • 基础crud项目(前端部分+总结)
  • Flink反压问题
  • 算法 --- 分治(归并)
  • 【Markdown转Word完整教程】从原理到实现
  • VOC、COCO、YOLO、YOLO OBB格式的介绍
  • AgentThink:一种在自动驾驶视觉语言模型中用于工具增强链式思维推理的统一框架
  • 深入剖析Spring Boot / Spring 应用中可自定义的扩展点
  • elasticsearch学习(五)文档CRUD
  • 基于脚手架微服务的视频点播系统-界面布局部分(二):用户界面及系统管理界面布局
  • 02-ideal2025 Ultimate版安装教程
  • SPI flash挂载fatfs文件系统
  • 什么是静态住宅IP 跨境电商为什么要用静态住宅IP
  • More Effective C++ 条款28:智能指针
  • 稠密矩阵和稀疏矩阵的对比
  • 神马 M21 31T 矿机解析:性能、规格与市场应用
  • Python多序列同时迭代完全指南:从基础到高并发系统实战
  • vcruntime140_1.dll缺失?5个高效解决方法
  • 手机秒变全栈IDE:Claude Code UI的深度体验
  • SpringBoot实现国际化(多语言)配置
  • MySQL 8.0 主从复制原理分析与实战
  • 深入解析Java HashCode计算原理 少看大错特错的面试题
  • 多线程——线程状态
  • 并发编程——17 CPU缓存架构详解高性能内存队列Disruptor实战
  • ResNet(残差网络)-彻底改变深度神经网络的训练方式