K-近邻算法(KNN算法)的K值的选取--交叉验证+网格搜索
选择K近邻算法(KNN)中的最佳K值是至关重要的,因为它直接影响到模型的性能(偏差-方差权衡)。交叉验证(Cross-Validation)结合网格搜索(Grid Search) 是解决这个问题的标准且强大的方法。
一.为什么需要交叉验证和网格搜索
K太小 ( K=1): 模型对训练数据中的噪声或局部波动非常敏感。决策边界变得非常崎岖。模型具有低偏差但高方差,容易过拟合。
K太大 (K=接近样本总数): 模型过于平滑,忽略了数据的局部结构。决策边界趋于线性。模型具有高偏差但低方差,容易欠拟合。
目标: 找到一个“恰到好处”的K值,在偏差和方差之间取得最佳平衡,从而获得最佳的泛化性能(在新数据上的预测能力)
二.交叉验证
交叉验证(Cross-Validation)是机器学习和统计建模中评估模型性能、选择模型、调整超参数的核心技术。它的核心思想是重复利用有限的数据集,通过多次划分训练集和测试集来获得更可靠、稳定的模型性能估计。
核心作用
更可靠地评估模型泛化能力:
传统的单次训练集/测试集划分(如 70%/30%)存在很大的随机性。不同的划分可能导致模型性能评估结果差异很大(高方差)。
交叉验证通过多次划分、多次训练和测试,计算性能指标(如准确率、精确率、召回率、F1、MSE、MAE、R²等)的平均值,显著降低评估结果对单次数据划分的依赖,提供更稳定、更可信的泛化误差估计。
有效利用有限数据:
在数据集较小的情况下,如果只用一小部分作为测试集(如 30%),用于训练的数据量会进一步减少,可能不足以训练出好的模型。
如果留出较大的测试集,评估结果可能不够可靠(测试样本太少)。
交叉验证(尤其是 k 折)让每个数据点都有机会作为测试集的一部分被评估一次,同时最大化了用于训练的数据量(每次训练使用约 (k-1)/k 的数据)。
交叉验证在 KNN 中的核心作用
负责评估模型泛化能力,避免过拟合,为网格搜索提供可靠的性能指标。
可靠性验证:
将数据集划分为k
份(如5或10份),轮流用其中k-1
份训练,剩余1份验证,重复k
次。综合k
次结果得到平均性能(如准确率),减少因数据划分随机性导致的评估偏差。防止过拟合:
避免使用单一训练/测试集划分时可能出现的“巧合性高得分”,确保模型在未见数据上表现稳定。支持参数选择:
为网格搜索中每组超参数提供客观的性能评分依据。
三.网格搜索
网格搜索(Grid Search)是一种系统化、穷举式的超参数优化方法。它的核心目标是在指定的超参数空间中找到能使模型在给定评估指标上表现最佳的超参数组合。
网格搜索的作用
自动化超参数调优:
最核心的作用。它取代了手动尝试不同参数组合的繁琐、低效且容易遗漏的过程,实现了调优的自动化。系统地探索参数空间:
通过穷举定义好的网格中的所有组合,确保不会遗漏网格定义范围内的任何可能组合。这比随机尝试或手动选择更全面(在定义的网格内)。找到(在指定网格内)最优组合:
其目标是明确地找到在给定的参数网格和评估指标下,性能最优的超参数配置。结果具有可重复性。提升模型性能:
通过找到更优的超参数组合,通常能显著提升模型在验证集/测试集上的泛化性能。理解参数影响:
虽然主要目的是找最优解,但查看不同参数组合的得分结果,也能帮助理解不同超参数及其取值对模型性能的影响趋势(例如,某个参数增大时模型性能如何变化)。
网格搜索的优缺点
优点:
原理简单直观,易于理解和实现。
结果可重复。
只要网格定义合理且计算资源足够,就能找到网格内的全局最优解(相对于定义的网格和评估指标)。
缺点:
计算成本高昂: 这是最大缺点。超参数的数量和每个参数的候选值数量增加会指数级地增加需要评估的组合总数(维度灾难)。对于大型模型或复杂网格,可能耗时极长甚至不可行。
网格定义依赖性强: 最佳结果严格依赖于你定义的参数网格。如果最佳参数不在你设定的网格点上(比如真实最优学习率是0.05,但你的网格是[0.01, 0.1, 1]),或者你设定的范围根本不包括最优值,网格搜索就找不到它。网格太粗可能错过最优,网格太细计算量爆炸。
维度灾难: 对高维参数空间(很多超参数需要调)非常不友好。
总结:网格搜索是一种通过穷举遍历预定义超参数组合网格,并利用交叉验证评估每个组合性能,从而自动化寻找最优超参数配置的方法。它的主要作用是提升模型性能和自动化调优过程。其核心优势在于简单性和在中小网格上的全局搜索能力,但最大的局限在于高昂的计算成本,尤其是在参数维度高或候选值多时。因此,在实际应用中,对于参数空间较大的情况,常常会先用较粗的网格搜索缩小范围,或者转而使用随机搜索(Randomized Search) 或更高级的贝叶斯优化(Bayesian Optimization)等方法。
网格搜索在 KNN 中的核心作用
负责自动化搜索最优超参数组合(如k
值、距离度量方式)。
作用
穷举参数组合:
预先定义超参数的候选值(如k = [3, 5, 10]
,距离 = ["欧氏距离", "曼哈顿距离"]
),遍历所有组合。依赖交叉验证评分:
对每组参数,调用交叉验证计算平均性能,选出得分最高的超参数组合。避免手动试错:
系统化替代人工调参,高效找到全局较优解。
四.二者协同工作流程
定义参数网格:确定待调参数(如k
、weights
、metric
)。
网格搜索启动:遍历参数网格中的每一组参数。
交叉验证执行:对当前参数组合,使用交叉验证计算模型性能均值。
选择最优参数:网格搜索比较所有参数组合的交叉验证得分,选择最优组合。
最终评估:用最优参数在独立测试集上验证模型性能。
graph LR
A[定义参数网格] --> B[网格搜索遍历参数]
B --> C[对每组参数执行交叉验证]
C --> D[计算交叉验证平均得分]
D --> E[选择得分最高的参数]
E --> F[在测试集评估最终模型]
技术 | 核心分工 | 关键作用 |
---|---|---|
交叉验证 | 评估模型泛化能力 | 提供可靠、稳定的性能评分,减少随机性影响 |
网格搜索 | 自动化搜索最优超参数 | 高效遍历参数空间,依赖交叉验证结果做决策 |
二者结合相当于:网格搜索是“调参员”,交叉验证是“公正的考官” —— 前者尝试不同配置,后者客观评分,共同确保KNN模型达到最佳泛化性能。
五.综合案例代码展示
# 导入必要的库
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
plt.rcParams['font.sans-serif'] = ['SimHei']
# 加载鸢尾花数据集
iris = load_iris()
X, y = iris.data, iris.target
feature_names, target_names = iris.feature_names, iris.target_names# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y
)# 数据标准化
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)# =================================
# 1. 使用交叉验证评估不同k值的性能
# =================================
print("1. 交叉验证评估不同k值性能")
k_values = list(range(1, 31))
cv_scores = []for k in k_values:knn = KNeighborsClassifier(n_neighbors=k)scores = cross_val_score(knn, X_train_scaled, y_train, cv=5, scoring='accuracy')cv_scores.append(scores.mean())# 找到最佳k值
best_k = k_values[np.argmax(cv_scores)]
print(f"交叉验证找到的最佳k值: {best_k}")
print(f"交叉验证最高准确率: {max(cv_scores):.4f}")# 绘制交叉验证结果
plt.figure(figsize=(10, 6))
plt.plot(k_values, cv_scores, 'o-', color='blue')
plt.axvline(x=best_k, color='red', linestyle='--')
plt.title('K值选择与交叉验证准确率')
plt.xlabel('k值')
plt.ylabel('5折交叉验证准确率')
plt.grid(True)
plt.show()# =================================
# 2. 使用网格搜索优化多个超参数
# =================================
print("\n2. 网格搜索优化多个超参数")
# 定义参数网格
param_grid = {'n_neighbors': list(range(1, 31)),'weights': ['uniform', 'distance'],'p': [1, 2] # 1:曼哈顿距离, 2:欧氏距离
}# 创建网格搜索对象
grid_search = GridSearchCV(KNeighborsClassifier(),param_grid,cv=5,scoring='accuracy',n_jobs=-1
)# 执行网格搜索
grid_search.fit(X_train_scaled, y_train)# 输出最佳参数
print(f"网格搜索找到的最佳参数: {grid_search.best_params_}")
print(f"网格搜索最高交叉验证准确率: {grid_search.best_score_:.4f}")# =================================
# 3. 评估最终模型
# =================================
print("\n3. 评估最优模型在测试集上的表现")
best_model = grid_search.best_estimator_
y_pred = best_model.predict(X_test_scaled)
test_accuracy = accuracy_score(y_test, y_pred)print(f"测试集准确率: {test_accuracy:.4f}")
print("\n分类报告:")
print(classification_report(y_test, y_pred, target_names=target_names))# 绘制混淆矩阵
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 6))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('混淆矩阵')
plt.colorbar()
tick_marks = np.arange(len(target_names))
plt.xticks(tick_marks, target_names, rotation=45)
plt.yticks(tick_marks, target_names)
plt.ylabel('真实标签')
plt.xlabel('预测标签')# 添加数字标签
thresh = cm.max() / 2.
for i in range(cm.shape[0]):for j in range(cm.shape[1]):plt.text(j, i, format(cm[i, j], 'd'),horizontalalignment="center",color="white" if cm[i, j] > thresh else "black")plt.tight_layout()
plt.show()# =================================
# 4. 可视化网格搜索结果
# =================================
print("\n4. 网格搜索结果可视化")
# 提取结果
results = grid_search.cv_results_
mean_scores = results['mean_test_score'].reshape(len(param_grid['n_neighbors']), len(param_grid['weights']), len(param_grid['p']))# 创建热力图
fig, axes = plt.subplots(1, 2, figsize=(15, 6), sharey=True)# 曼哈顿距离 (p=1)
ax1 = axes[0]
im1 = ax1.imshow(mean_scores[:, :, 0].T, aspect='auto', cmap='viridis')
ax1.set_title('曼哈顿距离 (p=1)')
ax1.set_xlabel('k值')
ax1.set_ylabel('权重类型')
ax1.set_xticks(np.arange(len(param_grid['n_neighbors'])))
ax1.set_xticklabels(param_grid['n_neighbors'])
ax1.set_yticks([0, 1])
ax1.set_yticklabels(param_grid['weights'])
fig.colorbar(im1, ax=ax1, label='准确率')# 欧氏距离 (p=2)
ax2 = axes[1]
im2 = ax2.imshow(mean_scores[:, :, 1].T, aspect='auto', cmap='viridis')
ax2.set_title('欧氏距离 (p=2)')
ax2.set_xlabel('k值')
ax2.set_yticks([0, 1])
ax2.set_yticklabels(param_grid['weights'])
ax2.set_xticks(np.arange(len(param_grid['n_neighbors'])))
ax2.set_xticklabels(param_grid['n_neighbors'])
fig.colorbar(im2, ax=ax2, label='准确率')plt.suptitle('网格搜索参数性能热力图', fontsize=16)
plt.tight_layout()
plt.show()