KNN算法详解:从原理到实战(鸢尾花分类 手写数字识别)
KNN算法详解:从原理到实战(鸢尾花分类 & 手写数字识别)
摘要:本文系统讲解K近邻算法(K-Nearest Neighbors, KNN)的核心思想、分类与回归流程、距离度量方法、特征预处理的重要性,并结合鸢尾花分类和手写数字识别两个经典案例,深入剖析KNN的实际应用。同时介绍交叉验证与网格搜索等超参数调优技术,帮助读者全面掌握KNN算法的理论与实践。
一、KNN算法简介
KNN(K-Nearest Neighbors)是一种简单而有效的监督学习算法,常用于分类和回归任务。其核心思想是:“物以类聚”——一个样本的类别或值由其在特征空间中最邻近的k个样本决定。
核心思想
如果一个样本在特征空间中的k个最相似(即距离最近)的样本大多数属于某一类别,则该样本也倾向于属于这个类别。
通俗理解:“看你的邻居是谁,就知道你是谁。”
二、K值的选择
K值是KNN算法中最重要的超参数之一,直接影响模型的性能:
- K值过小:模型对噪声敏感,容易过拟合。
- K值过大:模型趋于平滑,可能忽略局部特征,导致欠拟合。
通常通过交叉验证 + 网格搜索来选择最优K值。
三、KNN的两种应用方式
1. 分类问题处理流程
- 计算未知样本到每个训练样本的距离;
- 将训练样本按距离升序排列;
- 取出距离最近的K个样本;
- 统计这K个样本中各类别出现的频次;
- 将未知样本归为频次最高的类别(多数表决)。
2. 回归问题处理流程
- 计算未知样本到每个训练样本的距离;
- 将训练样本按距离升序排列;
- 取出距离最近的K个样本;
- 计算这K个样本目标值的平均值;
- 将该平均值作为未知样本的预测值。
四、常用距离度量方法
样本之间的“相似性”通常通过距离来衡量。常见距离度量包括:
距离类型 | 公式 | 特点 |
---|---|---|
欧氏距离(Euclidean) | d=∑i=1n(xi−yi)2d = \sqrt{\sum_{i=1}^{n}(x_i - y_i)^2}d=∑i=1n(xi−yi)2 | 最常用,表示两点间的直线距离 |
曼哈顿距离(Manhattan) | d=∑i=1n∣xi−yi∣d = \sum_{i=1}^{n}|x_i - y_i|d=∑i=1n∣xi−yi∣ | 城市街区距离,对异常值较鲁棒 |
切比雪夫距离(Chebyshev) | d=maxi∣xi−yi∣d = \max_i |x_i - y_i|d=maxi∣xi−yi∣ | 各维度最大差值 |
闵可夫斯基距离(Minkowski) | d=(∑i=1n∣xi−yi∣p)1/pd = \left(\sum_{i=1}^{n}|x_i - y_i|^p\right)^{1/p}d=(∑i=1n∣xi−yi∣p)1/p | 欧氏和曼哈顿的泛化形式,当p=2为欧氏,p=1为曼哈顿 |
五、特征预处理:为什么需要归一化/标准化?
问题背景
当特征的量纲或数值范围差异较大时(如身高cm vs 体重kg),某些特征会因数值大而在距离计算中占据主导地位,影响模型学习效果。
解决方案
归一化(Normalization)
将数据缩放到固定区间(如[0,1])。
- API:
sklearn.preprocessing.MinMaxScaler
- 公式:x′=x−xminxmax−xminx' = \frac{x - x_{min}}{x_{max} - x_{min}}x′=xmax−xminx−xmin
- 优点:保留原始分布形状
- 缺点:对异常值敏感(受极值影响大)
标准化(Standardization)
将数据转换为均值为0、标准差为1的标准正态分布。
- API:
sklearn.preprocessing.StandardScaler
- 公式:x′=x−μσx' = \frac{x - \mu}{\sigma}x′=σx−μ
- 优点:对异常值不敏感,适合大多数机器学习模型
- 缺点:结果无固定范围
推荐使用标准化,尤其在数据分布接近正态时。
from sklearn.preprocessing import StandardScalerscaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test) # 注意:测试集使用训练集的参数
六、实战案例1:KNN实现鸢尾花分类
数据集介绍
鸢尾花数据集(Iris Dataset)包含150条样本,每类50条,共3个类别:
- Setosa
- Versicolour
- Virginica
每条样本有4个特征:
- 花萼长度(sepal length)
- 花萼宽度(sepal width)
- 花瓣长度(petal length)
- 花瓣宽度(petal width)
完整代码实现
# 导入库
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score# 1. 加载数据
iris = load_iris()
X, y = iris.data, iris.target# 2. 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=22)# 3. 特征标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)# 4. 训练KNN模型
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)# 5. 预测与评估
y_pred = knn.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"测试集准确率: {accuracy:.4f}")
输出示例:
测试集准确率: 0.9667
七、超参数调优:交叉验证 + 网格搜索
交叉验证和网格搜索的基本概念补充
交叉验证(Cross-Validation)
什么是交叉验证?
交叉验证是一种对数据集进行划分的方法,用于评估模型的泛化能力。它通过将原始数据集划分为多个子集,并在不同的子集上进行训练和验证,从而减少由于数据集划分带来的偏差。
具体步骤:
- 数据集划分:将数据集随机划分为 ( k ) 个大小相等的子集(称为“折”),通常 ( k ) 的值为5或10。
- 轮换训练与验证:
- 第一次:选择第1个子集作为验证集,其余 ( k-1 ) 个子集合并为训练集。
- 第二次:选择第2个子集作为验证集,其余 ( k-1 ) 个子集合并为训练集。
- 以此类推,共进行 ( k ) 次训练和验证。
- 性能评估:计算每次验证的性能指标(如准确率、F1分数等),并取平均值作为最终的模型评估结果。
优点:
- 减少偏差:通过多次训练和验证,减少了由于单次数据划分带来的偶然性误差。
- 充分利用数据:每个样本都有机会被用作验证集,提高了数据利用率。
缺点:
- 计算成本高:需要进行 ( k ) 次训练和验证,计算量较大。
应用场景:
- 适用于小规模数据集,能够更准确地评估模型性能。
- 在超参数调优中,常与网格搜索结合使用。
网格搜索(Grid Search)
为什么需要网格搜索?
在机器学习模型中,超参数的选择对模型性能有重要影响。手动调整超参数不仅耗时费力,而且难以找到最优组合。网格搜索提供了一种系统化的方法来自动寻找最优超参数。
基本概念:
- 超参数空间:定义所有可能的超参数组合。
- 网格搜索过程:遍历超参数空间中的每一个组合,对每个组合训练一个模型,并评估其性能。
具体步骤:
- 定义超参数范围:确定每个超参数的候选值集合。
- 构建超参数组合:生成所有可能的超参数组合。
- 模型训练与评估:对于每组超参数,使用交叉验证方法训练和评估模型。
- 选择最优组合:根据评估结果,选择性能最佳的超参数组合。
优点:
- 全面搜索:确保不会遗漏任何潜在的最优组合。
- 自动化:减少了手动调整超参数的工作量。
缺点:
- 计算复杂度高:随着超参数数量和候选值的增加,计算量呈指数级增长。
- 可能陷入局部最优:在高维超参数空间中,可能存在多个局部最优解。
应用场景:
- 适用于超参数较少且计算资源充足的情况。
- 常与交叉验证结合使用,以提高模型评估的准确性。
网格搜索 + 交叉验证的强强组合
模型选择和调优:
- 交叉验证解决数据输入问题:通过多次训练和验证,得到更可靠的模型评估结果,避免了数据划分带来的偶然性误差。
- 网格搜索寻找最优超参数组合:遍历所有可能的超参数组合,找到性能最佳的模型配置。
- 综合应用:两者结合形成一个完整的模型参数调优解决方案,既保证了模型评估的准确性,又找到了最优的超参数配置。
为了找到最优的K值,我们可以使用GridSearchCV进行自动搜索。
from sklearn.model_selection import GridSearchCV# 定义参数搜索空间
param_grid = {'n_neighbors': range(1, 11)}# 创建KNN分类器
knn = KNeighborsClassifier()# 网格搜索 + 5折交叉验证
grid_search = GridSearchCV(knn, param_grid, cv=5, scoring='accuracy')
grid_search.fit(X_train, y_train)# 输出最优参数
print("最优K值:", grid_search.best_params_)
print("交叉验证最高得分:", grid_search.best_score_)# 使用最优模型预测
best_knn = grid_search.best_estimator_
test_score = best_knn.score(X_test, y_test)
print("测试集得分:", test_score)
输出示例:
最优K值: {'n_neighbors': 6} 交叉验证最高得分: 0.975 测试集得分: 0.9667
八、实战案例2:KNN实现手写数字识别(MNIST)
数据集介绍
- 图像大小:28×28 像素 → 共784个特征
- 像素值范围:[0, 255](0为白色,255为黑色)
- 标签:0~9 的数字
- 数据文件:
train.csv
(含标签),test.csv
(无标签)
完整流程
1. 数据加载与可视化
import pandas as pd
import matplotlib.pyplot as plt# 加载数据
data = pd.read_csv('train.csv')
X = data.iloc[:, 1:] # 像素特征
y = data.iloc[:, 0] # 标签# 显示某张图像
def show_digit(idx):digit = X.iloc[idx].values.reshape(28, 28)plt.imshow(digit, cmap='gray')plt.title(f"Label: {y.iloc[idx]}")plt.axis('off')plt.show()show_digit(0) # 显示第一张图
2. 特征归一化
由于像素值范围为[0,255],需归一化到[0,1]:
X = X / 255.0 # 缩放到0~1
3. 数据集划分
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y, random_state=21
)
stratify=y
:保持训练/测试集中各类别比例一致。
4. 模型训练与评估
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X_train, y_train)# 评估
score = knn.score(X_test, y_test)
print(f"手写数字识别准确率: {score:.4f}")
输出示例:
手写数字识别准确率: 0.9681
5. 模型保存与预测新图像
import joblib# 保存模型
joblib.dump(knn, 'knn_mnist.pkl')# 加载模型并预测新图像
model = joblib.load('knn_mnist.pkl')
img = plt.imread('demo.png').reshape(1, -1) / 255.0 # 预处理
pred = model.predict(img)
print("预测结果:", pred[0])
九、总结
要点总结
内容 | 关键点 |
---|---|
KNN思想 | “近朱者赤”,基于邻居投票 |
K值选择 | 小K易过拟合,大K易欠拟合,建议用网格搜索 |
距离度量 | 欧氏最常用,曼哈顿对异常值更鲁棒 |
特征预处理 | 必须做!推荐使用StandardScaler |
应用场景 | 分类(多数表决)、回归(取均值) |
调优方法 | GridSearchCV + cross-validation |
结语
KNN虽然原理简单,但在小样本、低维数据上表现优异,是理解机器学习“相似性”概念的绝佳起点。掌握其核心思想与工程实践技巧,将为后续学习更复杂的模型打下坚实基础。
提示:在高维数据(如图像)上,KNN计算开销大,可考虑降维(PCA)或使用更高效的近似算法(如KD-Tree)。
需要代码和数据的朋友:跳转gitee
代码路径自行更改