当前位置: 首页 > backend >正文

遥感amp;机器学习入门实战教程 | Sklearn 案例③:PCA + SVM / 随机森林 对比与调参

前两篇我们完成了 无泄露 PCA 预处理k-NN 分类。链接,https://mp.weixin.qq.com/mp/appmsgalbum?__biz=MzkwMTE0MjI4NQ==&action=getalbum&album_id=4119240160615596034#wechat_redirect
本篇引入更强的 支持向量机 (SVM)随机森林 (RF),利用交叉验证自动调参,形成可靠的基线模型,并增加整图预测与参数可视化。

🎯 本文目标

  1. 保持严格无数据泄露:StandardScalerPCA 仅在训练集上 fit
  2. 使用 GridSearchCV 在训练集上做超参数搜索。
  3. 输出完整评估指标:OA / AA / Kappa + 分类报告。
  4. 绘制混淆矩阵(新配色)与 PCA 累计解释方差曲线。
  5. 可视化参数影响:SVM 的 C×gamma 热力图,RF 的 n_estimators×max_depth 热力图。
  6. 整图预测:不遮挡未知区域,完整渲染。

📂 数据准备

与前篇相同:

  • KSC.mat:高光谱数据 (H, W, B)
  • KSC_gt.mat:标签图 (H, W),其中 0=背景,1…C 为类别
your_path/├─ KSC.mat└─ KSC_gt.mat

只需修改脚本中的 DATA_DIR = r"your_path"

① 环境与依赖

import os, time, json
import numpy as np
import scipy.io as sio
import matplotlib
import matplotlib.pyplot as pltfrom sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, cohen_kappa_score

说明:导入 Numpy/Scipy 做数据处理,Matplotlib 绘图,Scikit-learn 提供 PCA、SVM、RF 与评估工具。

② 参数设置

DATA_DIR = r"your_path"   # 修改为你的数据目录
PCA_DIM = 30
TRAIN_RATIO = 0.3
SEED = 42
N_FOLDS = 5

说明:这里设置 PCA 主成分数=30,训练集占比=30%,并固定随机种子保证可复现。

③ 加载与划分数据

X = sio.loadmat(os.path.join(DATA_DIR, "KSC.mat"))["KSC"].astype(np.float32)
Y = sio.loadmat(os.path.join(DATA_DIR, "KSC_gt.mat"))["KSC_gt"].astype(int)
h, w, b = X.shapecoords = np.argwhere(Y != 0)              # 有标签像素坐标
labels = Y[coords[:, 0], coords[:, 1]]-1  # 转为 0-based

说明:只取有标签像素做监督学习,避免背景 0 干扰。

④ 无泄露预处理:仅用训练像素 fit

from sklearn.model_selection import train_test_splittrain_ids, test_ids = train_test_split(np.arange(len(coords)), train_size=TRAIN_RATIO,stratify=labels, random_state=SEED)train_pixels = X[coords[train_ids,0], coords[train_ids,1]]scaler = StandardScaler().fit(train_pixels)
pca = PCA(n_components=PCA_DIM, random_state=SEED).fit(scaler.transform(train_pixels))

说明

  • fit 仅用训练像素:防止测试信息泄露。
  • 顺序:StandardScalerPCA

⑤ 整图统一变换

X_flat = X.reshape(-1, b)
X_std  = scaler.transform(X_flat)
X_pca_flat = pca.transform(X_std)
X_pca = X_pca_flat.reshape(h, w, PCA_DIM)X_train = X_pca[coords[train_ids,0], coords[train_ids,1]]
y_train = labels[train_ids]
X_test  = X_pca[coords[test_ids,0],  coords[test_ids,1]]
y_test  = labels[test_ids]

说明:整幅影像使用相同参数变换;之后提取训练/测试像素用于建模。

⑥ PCA 累计解释方差曲线

cum_var = np.cumsum(pca.explained_variance_ratio_)
plt.plot(np.arange(1, len(cum_var)+1), cum_var, marker='o')
plt.axhline(0.95, ls='--', label="95% 阈值")
plt.axvline(PCA_DIM, ls='--', label=f"n={PCA_DIM}")
plt.xlabel("主成分数"); plt.ylabel("累计解释方差比")
plt.title("PCA累计解释方差曲线"); plt.legend(); plt.show()

说明:辅助判断保留多少维合适(常取解释方差 >95%)。

⑦ 模型与网格搜索

skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)svm_grid = GridSearchCV(SVC(kernel="rbf"),param_grid={"C":[1,5,10,20,50,100],"gamma":["scale",0.01,0.005,0.001]},cv=skf, n_jobs=-1).fit(X_train, y_train)rf_grid = GridSearchCV(RandomForestClassifier(random_state=SEED,n_jobs=-1),param_grid={"n_estimators":[200,400,800],"max_depth":[None,10,20,40]},cv=skf, n_jobs=-1).fit(X_train, y_train)

说明:使用分层K折交叉验证,在训练集上搜索最优参数。

⑧ 参数可视化

SVM:C×gamma 热力图

Cs=[1,5,10,20,50,100]; Gs=["scale",0.01,0.005,0.001]
Z=np.full((len(Gs),len(Cs)),np.nan)
for mean,params in zip(svm_grid.cv_results_["mean_test_score"], svm_grid.cv_results_["params"]):i,j = Gs.index(str(params["gamma"])), Cs.index(params["C"])Z[i,j]=mean
plt.imshow(Z,cmap="viridis"); plt.colorbar(label="CV准确率")
plt.xticks(range(len(Cs)),Cs); plt.yticks(range(len(Gs)),Gs)
plt.title("SVM参数热力图"); plt.show()

说明:直观展示不同 Cgamma 的组合对准确率的影响。

⑨ 测试集评估

from sklearn.metrics import classification_reportsvm_best=svm_grid.best_estimator_; rf_best=rf_grid.best_estimator_
y_pred_svm=svm_best.predict(X_test); y_pred_rf=rf_best.predict(X_test)print("SVM最佳参数:",svm_grid.best_params_)
print(classification_report(y_test,y_pred_svm))
print("RF最佳参数:",rf_grid.best_params_)
print(classification_report(y_test,y_pred_rf))

说明:输出每类的 precision/recall/F1,以及整体 OA / Kappa。

⑩ 混淆矩阵(新配色)

cm = confusion_matrix(y_test,y_pred_svm)
plt.imshow(cm,cmap="Blues"); plt.title("SVM混淆矩阵"); plt.colorbar(); plt.show()cmn=cm/cm.sum(axis=1,keepdims=True)
plt.imshow(cmn,cmap="YlGnBu",vmin=0,vmax=1); plt.title("SVM混淆矩阵(归一化)")
plt.colorbar(); plt.show()

说明:两版配色:计数=Blues,归一化=YlGnBu,更直观。

⑪ 整图预测(完整渲染)

best_model=svm_best if accuracy_score(y_test,y_pred_svm)>=accuracy_score(y_test,y_pred_rf) else rf_best
pred_map=best_model.predict(X_pca_flat).reshape(h,w)+1from matplotlib.colors import ListedColormap
cmap=ListedColormap([plt.cm.tab20(i%20) for i in range(len(np.unique(labels)))])
plt.imshow(pred_map,cmap=cmap); plt.title("整图预测结果"); plt.axis("off"); plt.show()

说明:直接对整幅图预测,不遮挡背景区域。

🔚 总结

  • 我们在无泄露前提下,对 SVMRF 做了自动调参。
  • 输出了完整指标、混淆矩阵、新配色图表。
  • 可视化了超参数影响,帮助理解模型行为。
  • 实现了整图预测,方便直观对比。

结果:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

完整代码:通过网盘分享的文件:Sklearn 案例3.py
链接: https://pan.baidu.com/s/1a4rD0fvjqXBwlz3K9hELkg 提取码: 123z
–来自百度网盘超级会员v6的分享

🔗 下一篇预告

第④篇将引入 1D-CNN 深度学习基线,比较传统机器学习与深度模型在小样本高光谱分类中的差异。

欢迎大家关注下方我的公众获取更多内容!

http://www.xdnf.cn/news/18181.html

相关文章:

  • LAMP架构编译安装部署
  • 垂直领域大模型构建:法律行业“类ChatGPT”系统的训练与落地
  • PythonDay31
  • Vue2+Vue3前端开发_Day1
  • Fragment重要知识点总结
  • Incredibuild 新增 Unity 支持:击破构建时间过长的痛点
  • 机器学习(决策树2)
  • MVVM开源项目
  • Netty处理粘包与拆包
  • vue使用vue-cropper实现图片裁剪之单图裁剪
  • 关于mybatis表关联查询和mybatis-Plus单表查询传入时间查询数据(走索引)
  • Linux Namespace 隔离的“暗面”——故障排查、认知误区与演进蓝图
  • CVPR 2025 | 具身智能 | HOLODECK:一句话召唤3D世界,智能体的“元宇宙练功房”来了
  • 【HTML】3D动态凯旋门
  • 通过C#上位机串口写入和读取浮点数到stm32的片内flash实战4(通过串口下发AD9833设置值并在上位机显示波形曲线)
  • “你不干有的是AI干”,提示词中的“情感化提示”
  • 如何在 Ubuntu Linux 上安装 RPM 软件包
  • 【SQL优化案例】统计信息缺失
  • Vercel v0 iOS版重磅发布:AI驱动的移动开发新篇章
  • 如何解决pip安装报错ModuleNotFoundError: No module named ‘paramiko’问题
  • C++入门自学Day14-- Stack和Queue的自实现(适配器)
  • Java高级面试实战:Spring Boot微服务与Redis缓存整合案例解析
  • “R语言+遥感”的水环境综合评价方法实践技术应用
  • Centos7物理安装 Redis8.2.0
  • 【GNSS定位原理及算法杂记6】​​​​​​PPP(精密单点定位)原理,RTK/PPK/PPP区别讨论
  • 【部署相关】DockerKuberbetes常用命令大全(速查+解释)
  • 孩子王披露半年报:多数据持续增长,全年预期增强
  • git仓库和分支的关系
  • Linux GPIO子系统中开漏模式软件仿真机制的深度分析
  • 【深度学习计算性能】06:多GPU的简洁实现