当前位置: 首页 > ops >正文

深度学习与遥感入门(六)|轻量化 MobileNetV2 高光谱分类

系列回顾:
(一)CNN 基础:高光谱图像分类可视化全流程
(二)HybridNet(CNN+Transformer):提升全局感受野
(三)GCN 入门实战:基于光谱 KNN 的图卷积分类与全图预测
(四)空间–光谱联合构图的 GCN:RBF 边权 + 自环 + 早停,得到更稳更自然的全图分类结果
(五)GAT & 构图消融 + 分块全图预测:更稳更快的高光谱图分类(PyTorch Geometric 实战)
合集链接:https://mp.weixin.qq.com/mp/appmsgalbum?__biz=MzkwMTE0MjI4NQ==&action=getalbum&album_id=4007114522736459789#wechat_redirect
本篇(六)聚焦“数据泄露”,采用仅训练集像素拟合 StandardScaler+PCA,并在全图预测中共享同一变换空间;模型选用轻量化 MobileNetV2深度可分离卷积,在显存友好的坐标批推理下实现全图预测

0. 前言:PCA 与高光谱分类中的“数据泄露”

  • 什么是泄露? 训练阶段直接/间接使用了测试数据统计信息(均值、方差、主成分方向等)。
  • 怎么产生? 在整图上 fit 标准化与 PCA,然后再切训练/测试或直接做全图分类。
  • 为什么常见? 历史习惯、样本少时稳定性考虑、对比研究图省事、实现方便。
  • 影响大吗? 小数据集上常为 0.1%~1% 的 OA 差异;但在类分布差异大训练样本极少时,差距可达数个百分点。真实部署场景绝不允许整图 fit

本文做法先分层抽样得到训练/测试索引仅用训练像素 fit 标准化与 PCA用该变换对整图 transform训练与预测均在同一(训练集拟合得到的)特征空间,从源头避免泄露。

1. 任务要点

  1. 严格无泄露预处理:只用训练像素拟合 StandardScaler+PCA;全图在同一变换空间中变换。
  2. 轻量模型:用 MobileNetV2 的深度可分离卷积(3 段)+ GAP + FC。
  3. 全图预测显存友好:按坐标批收集 patch → 堆成 batch → 前向推理。
  4. 评估classification_report、混淆矩阵、OA;可视化支持 Windows 阻塞显示。

2. 方法详解

2.1 严格无泄露的 PCA 流程

  • 先划分后拟合:对有标签像素做分层抽样得到训练/测试索引;仅训练像素拟合 StandardScalerPCA
  • 全图共享空间:将整图 (H×W×Bands) 用训练集拟合的变换进行标准化与降维,得到 (H×W×PCA_DIM)
  • 提取 patch:在 PCA 空间内按坐标提取 (PATCH_SIZE×PATCH_SIZE×PCA_DIM) 的 patch 作为输入。

这样做的关键测试像素从未参与统计,评估更可信。

2.2 轻量化 MobileNetV2(HSI 版)

  • Depthwise Separable Conv:逐通道 3×3 深度卷积 + 1×1 点卷积,大幅降参与算力需求。
  • 网络骨干:3 段深度可分离卷积 → GAP(自适应全局平均池化)→ 全连接输出。
  • 输入通道:这里输入为 PCA 后的通道数(例如 30),以二维 patch 形式输入(C×H×W)。

2.3 全图预测策略(坐标批)

  • 坐标遍历:生成所有像素坐标。
  • 反射填充:边界像素也能提取完整 patch。
  • 批量收集:按 batch_size 组装 patch → 前向 → 填回 pred_map
  • 显存稳定:避免一次性张量过大导致溢出。

3. 代码逐段 + 解释

下面先按逻辑分段展示与解释;最末提供“一键可跑脚本(整合版)”,复制后仅需修改数据路径即可运行。

3.1 全局与可视化设置

import os, time, numpy as np, scipy.io as sio
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.model_selection import train_test_split
import matplotlib# Windows 下 TkAgg 更稳;Linux/服务器用 Agg(无显示)
if os.name == 'nt':matplotlib.use('TkAgg')
else:matplotlib.use('Agg')import matplotlib.pyplot as plt
import seaborn as snsmatplotlib.rcParams['font.family'] = 'SimHei'
matplotlib.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 120
sns.set_theme(context="notebook", style="whitegrid", font="SimHei")torch.backends.cudnn.benchmark = True

3.2 随机种子

def set_seeds(seed=42):import randomrandom.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)

固定随机性,保证复现。

3.3 轻量化网络

class DepthwiseSeparableConv(nn.Module):def __init__(self, in_ch, out_ch, stride=1):super().__init__()self.depthwise = nn.Conv2d(in_ch, in_ch, 3, stride=stride, padding=1, groups=in_ch, bias=False)self.pointwise = nn.Conv2d(in_ch, out_ch, 1, bias=False)self.bn = nn.BatchNorm2d(out_ch)self.act = nn.ReLU6(inplace=True)def forward(self, x):x = self.depthwise(x)x = self.pointwise(x)x = self.bn(x)return self.act(x)class MobileNetV2_HSI(nn.Module):def __init__(self, in_ch, num_classes, width_mult=1.0):super().__init__()c1, c2, c3 = int(32 * width_mult), int(64 * width_mult), int(128 * width_mult)self.layer1 = DepthwiseSeparableConv(in_ch, c1)self.layer2 = DepthwiseSeparableConv(c1, c2)self.layer3 = DepthwiseSeparableConv(c2, c3)self.gap = nn.AdaptiveAvgPool2d(1)self.fc = nn.Linear(c3, num_classes)def forward(self, x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.gap(x).flatten(1)return self.fc(x)

结构极简但实用:3 段深度可分离卷积 + GAP + FC,对 HSI 的小样本/低算力场景友好。

3.4 数据集与 Patch 封装

class HSIPatchDataset(Dataset):def __init__(self, patches, labels):# patches: (N, H, W, C) → 张量 (N, C, H, W)self.X = torch.tensor(patches, dtype=torch.float32).permute(0, 3, 1, 2)self.y = torch.tensor(labels, dtype=torch.long)def __len__(self): return len(self.y)def __getitem__(self, idx): return self.X[idx], self.y[idx]

3.5 全图预测(坐标批推理)

@torch.inference_mode()
def predict_full_image_by_coords(model, X_img_pca, patch_size, device,batch_size=2048, title="全图预测(坐标批推理)"):model.eval()H, W, C = X_img_pca.shapem = patch_size // 2padded = np.pad(X_img_pca, ((m, m), (m, m), (0, 0)), mode='reflect')coords = np.mgrid[0:H, 0:W].reshape(2, -1).Tpred_map = np.zeros((H, W), dtype=np.int32)t0 = time.time()for i in range(0, len(coords), batch_size):batch_coords = coords[i:i + batch_size]patches = np.empty((len(batch_coords), patch_size, patch_size, C), dtype=np.float32)for k, (r, c) in enumerate(batch_coords):patches[k] = padded[r:r + patch_size, c:c + patch_size, :]tensor = torch.from_numpy(patches).permute(0, 3, 1, 2).to(device)preds = model(tensor).argmax(dim=1).cpu().numpy() + 1  # +1 便于和 GT 对齐for (r, c), p in zip(batch_coords, preds):pred_map[r, c] = pprint(f"全图预测耗时:{time.time() - t0:.2f} 秒")# 可视化(阻塞显示,避免“最后的图没有显示”)try:plt.figure(figsize=(10, 7.5))cmap = matplotlib.colormaps.get_cmap('tab20')vmin, vmax = pred_map.min(), pred_map.max()if vmin == vmax: vmin, vmax = 0, 1im = plt.imshow(pred_map, cmap=cmap, interpolation='nearest', vmin=vmin, vmax=vmax)cbar = plt.colorbar(im, shrink=0.85); cbar.set_label('预测类别', rotation=90)plt.title(title, fontsize=14, weight='bold'); plt.axis('off'); plt.tight_layout()print("尝试显示全图预测结果...")plt.show(block=True)except Exception as e:print(f"显示全图预测结果时出错: {e}")try:plt.savefig("prediction_map.png", bbox_inches='tight')print("已保存为 prediction_map.png")except Exception as se:print(f"保存失败: {se}")return pred_map

3.6 主流程(数据→划分→无泄露预处理→训练→评估→全图预测)

下面是主函数的关键片段(末尾附完整可运行脚本):

def main():set_seeds(42)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')print(f"使用设备: {device}")# ---- 路径与超参(按需修改)----DATA_DIR = r"your_path"X_FILE, Y_FILE = "KSC.mat", "KSC_gt.mat"PCA_DIM, PATCH_SIZE, TRAIN_RATIO = 30, 5, 0.30BATCH_SIZE, EPOCHS, LR, WEIGHT_DECAY = 64, 30, 1e-3, 1e-4NUM_WORKERS = 0 if os.name == 'nt' else min(4, os.cpu_count() or 0)PIN_MEMORY = (device.type == 'cuda')PREDICT_BATCH_SIZE = 4096# ---- 读取数据 ----def load_data():X = sio.loadmat(os.path.join(DATA_DIR, X_FILE))Y = sio.loadmat(os.path.join(DATA_DIR, Y_FILE))x_key = [k for k in X.keys() if not k.startswith("__")][0]y_key = [k for k in Y.keys() if not k.startswith("__")][0]return X[x_key], Y[y_key]X_img, Y_img = load_data()h, w, bands = X_img.shapeprint(f"数据尺寸: {h}×{w}, 波段: {bands}")# ---- 有标签索引 + 分层划分 ----labeled_idx_rc = np.array([(i, j) for i in range(h) for j in range(w) if Y_img[i, j] != 0])labels_all = np.array([Y_img[i, j] - 1 for i, j in labeled_idx_rc], dtype=np.int64)num_classes = len(np.unique(labels_all))print(f"有标签样本: {len(labeled_idx_rc)},类别数: {num_classes}")train_ids, test_ids = train_test_split(np.arange(len(labeled_idx_rc)),test_size=1 - TRAIN_RATIO, stratify=labels_all, random_state=42)# ---- 仅训练像素拟合 Scaler+PCA(无泄露)----print("拟合 StandardScaler/PCA(仅训练像素)...")train_pixels = np.array([X_img[i, j] for i, j in labeled_idx_rc[train_ids]], dtype=np.float32)scaler = StandardScaler().fit(train_pixels)pca = PCA(n_components=PCA_DIM, random_state=42).fit(scaler.transform(train_pixels))# 整图进入同一空间(float32)X_pca_img = pca.transform(scaler.transform(X_img.reshape(-1, bands).astype(np.float32))).astype(np.float32)X_pca_img = X_pca_img.reshape(h, w, PCA_DIM)# ---- 提取训练/测试 patch ----def extract_patches(sel_ids):m = PATCH_SIZE // 2padded = np.pad(X_pca_img, ((m, m), (m, m), (0, 0)), mode='reflect')patches = np.empty((len(sel_ids), PATCH_SIZE, PATCH_SIZE, PCA_DIM), dtype=np.float32)labs = np.empty((len(sel_ids),), dtype=np.int64)for n, k in enumerate(sel_ids):i, j = labeled_idx_rc[k]patches[n] = padded[i:i + PATCH_SIZE, j:j + PATCH_SIZE, :]labs[n] = labels_all[k]return patches, labsX_train, y_train = extract_patches(train_ids)X_test, y_test = extract_patches(test_ids)# ---- DataLoader ----train_loader = DataLoader(HSIPatchDataset(X_train, y_train), batch_size=BATCH_SIZE,shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)test_loader = DataLoader(HSIPatchDataset(X_test, y_test), batch_size=BATCH_SIZE,shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)# ---- 模型与优化器 ----model = MobileNetV2_HSI(PCA_DIM, num_classes).to(device)optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)criterion = nn.CrossEntropyLoss()scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)# ---- 评估函数 ----@torch.no_grad()def evaluate(loader):model.eval()all_y, all_pred = [], []for xb, yb in loader:xb = xb.to(device)pred = model(xb).argmax(dim=1).cpu().numpy()all_pred.extend(pred); all_y.extend(yb.numpy())return accuracy_score(all_y, all_pred), np.array(all_y), np.array(all_pred)# ---- 训练循环 ----print("开始训练...")best_acc, model_path = 0.0, "best_mnv2_hsi.pth"for epoch in range(1, EPOCHS + 1):model.train(); total_loss = 0.0for xb, yb in train_loader:xb, yb = xb.to(device), yb.to(device)optimizer.zero_grad()loss = criterion(model(xb), yb)loss.backward(); optimizer.step()total_loss += loss.item() * xb.size(0)test_acc, _, _ = evaluate(test_loader)scheduler.step(test_acc)print(f"Epoch {epoch:02d}/{EPOCHS} | 损失: {total_loss/len(train_loader.dataset):.4f} | 测试准确率: {test_acc:.4f}")if test_acc > best_acc:best_acc = test_acctorch.save(model.state_dict(), model_path)print(f"训练完成,最佳测试准确率:{best_acc:.4f}")# ---- 安全加载最佳权重 ----try:state = torch.load(model_path, map_location=device, weights_only=True)except TypeError:state = torch.load(model_path, map_location=device)model.load_state_dict(state)# ---- 测试报告 & 混淆矩阵 ----test_acc, y_true, y_pred = evaluate(test_loader)print("\n测试集分类报告:")print(classification_report(y_true, y_pred, digits=4, zero_division=0))plt.figure(figsize=(10, 7))class_names = [f"类{i + 1}" for i in range(num_classes)]sns.heatmap(confusion_matrix(y_true, y_pred),annot=True, fmt='d', cmap="Blues",xticklabels=class_names, yticklabels=class_names,cbar=False, square=True)plt.xlabel("预测标签"); plt.ylabel("真实标签")plt.title("MobileNetV2 测试集混淆矩阵", fontsize=14, weight='bold')plt.tight_layout(); plt.show(block=True)# ---- 全图预测 ----print("全图预测中(坐标→收集 patch→堆成 batch→前向)...")pred_map = predict_full_image_by_coords(model, X_pca_img, patch_size=PATCH_SIZE, device=device,batch_size=PREDICT_BATCH_SIZE, title="MobileNetV2 全图预测(坐标批推理)")print(f"预测图统计: min={pred_map.min()}, max={pred_map.max()}, mean={pred_map.mean():.3f}")print("完成。")

3.7 Windows 入口保护(多进程/显示更稳)

if __name__ == "__main__":try:import multiprocessing as mpmp.set_start_method("spawn", force=True)mp.freeze_support()except Exception:passmain()

4. 结果展示

在这里插入图片描述
在这里插入图片描述
欢迎大家关注下方我的公众获取更多内容!

http://www.xdnf.cn/news/17563.html

相关文章:

  • UNet改进(32):结合CNN局部建模与Transformer全局感知
  • HTTP应用层协议-长连接
  • (25.08)Ubuntu20.04+ROS1复现LIO-SAM
  • 2025年最新原创多目标算法:多目标酶作用优化算法(MOEAO)求解MaF1-MaF15及工程应用---盘式制动器设计,提供完整MATLAB代码
  • 【代码随想录day 18】 力扣 501.二叉搜索树中的众数
  • 力扣热题100------279.完全平方数
  • 吉利汽车7月销量超23.7万辆 同比增长58%
  • 【嵌入式C语言】
  • 【10】微网优联——微网优联 嵌入式技术一面,校招,面试问答记录
  • 数据结构:串、数组与广义表
  • IP分片(IP Fragmentation)
  • 力扣109:有序链表转换二叉搜索树
  • docter的使用、vscode(cursor)和docker的连接,详细分析说明
  • 【3D Gen 入坑(1)】Hunyuan3D-Paint 2.1 安装 `custom_rasterizer` 报错完整排查
  • 面试题-----RabbitMQ
  • MySQL的索引(索引的数据结构-B+树索引):
  • 嵌入式Linnux学习 -- 软件编程2
  • 【已解决】报错:WARNING: pip is configured with locations that require TLS/SSL
  • STM32——system文件夹
  • 【ros-humble】4.C++写法巡场海龟(服务通讯)
  • Spring Boot 中 @Transactional 解析
  • [Oracle] UNPIVOT 列转行
  • Linux kernel network stack, some good article
  • Day 37:早停策略和模型权重的保存
  • 《番外:Veda的备份,在某个未联网的旧服务器中苏醒……》
  • Mybatis学习之缓存(九)
  • 从零开始的云计算生活——第四十一天,勇攀高峰,Kubernetes模块之单Master集群部署
  • Seata
  • vue+django 大模型心理学智能诊断评测系统干预治疗辅助系统、智慧心理医疗、带知识图谱
  • EXISTS 替代 IN 的性能优化技巧