【图像处理基石】如何对遥感图像进行目标检测?
如何对遥感图像进行目标检测?
1. 遥感图像目标检测的基本流程
遥感图像目标检测是从卫星、无人机等遥感影像中自动识别和定位感兴趣目标(如建筑、车辆、机场等)的技术,核心流程包括:
- 数据预处理:辐射校正(消除传感器误差)、几何校正(修正地形/投影偏差)、裁剪/下采样(处理高分辨率数据);
- 特征提取:通过卷积神经网络(CNN)提取图像的纹理、形状、光谱等特征;
- 目标定位与分类:利用检测算法(如锚框机制、Transformer等)预测目标的位置(边界框)和类别;
- 后处理:非极大值抑制(NMS)去除冗余框,提升检测精度。
2. 遥感图像目标检测的难点
与自然图像(如手机拍摄的照片)相比,遥感图像的特殊性带来了独特挑战:
- 目标尺度差异极大:同一幅图像中可能同时存在千米级的机场和米级的车辆,尺度跨度可达1000倍以上;
- 目标方向任意:遥感图像为俯视视角,目标(如车辆、船只)可沿任意方向旋转,轴对齐边框(Axis-Aligned BBox)会引入大量背景噪声;
- 小目标密集分布:如停车场的车辆、城区的小型建筑,往往密集排列且像素占比低(可能仅10×10像素);
- 背景复杂且干扰强:地物(如道路、植被)与目标可能具有相似光谱/纹理特征(如车辆与路面颜色接近);
- 数据标注成本高:遥感图像分辨率高(单幅可达GB级),且专业标注需要领域知识(如区分“飞机”和“直升机”)。
3. 解决方案
针对上述难点,主流技术方案包括:
难点 | 解决方案 |
---|---|
尺度差异大 | 多尺度特征融合(如FPN)、动态锚框生成(根据图像内容自适应调整锚框尺度) |
目标方向任意 | 旋转边框(Rotated BBox)回归(如R2CNN、RRPN)、角度感知的损失函数 |
小目标密集 | 高分辨率特征保留(如CSPNet)、超分辨率重建(提升小目标细节)、密集检测头 |
背景复杂 | 注意力机制(如CBAM)抑制背景噪声、多模态融合(结合光谱/雷达数据) |
标注成本高 | 半监督学习(利用少量标注数据训练)、迁移学习(从自然图像模型迁移权重) |
PyTorch实现遥感图像目标检测(简化版)
以下实现一个支持旋转边框的简化检测模型,基于ResNet50+FPN提取特征,使用旋转锚框预测目标的位置(x, y, w, h, θ)和类别。
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import numpy as np
import cv2
import os
from PIL import Image# -------------------------- 1. 数据集定义 --------------------------
class RemoteSensingDataset(Dataset):def __init__(self, img_dir, ann_dir, img_size=(512, 512)):"""遥感数据集初始化:param img_dir: 图像文件夹路径:param ann_dir: 标注文件路径(每个图像对应一个txt,每行格式:x_center y_center w h angle class):param img_size: 图像resize尺寸"""self.img_dir = img_dirself.ann_dir = ann_dirself.img_size = img_sizeself.img_names = [f for f in os.listdir(img_dir) if f.endswith(('png', 'jpg'))]self.transform = transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet均值])def __len__(self):return len(self.img_names)def __getitem__(self, idx):img_name = self.img_names[idx]img_path = os.path.join(self.img_dir, img_name)ann_path = os.path.join(self.ann_dir, img_name.replace('.png', '.txt').replace('.jpg', '.txt'))# 读取图像img = Image.open(img_path).convert('RGB')img = self.transform(img)# 读取标注(旋转框:x_center, y_center, w, h, angle(弧度), class)boxes = []labels = []if os.path.exists(ann_path):with open(ann_path, 'r') as f:for line in f.readlines():xc, yc, w, h, angle, cls = map(float, line.strip().split())# 归一化坐标转绝对坐标xc *= self.img_size[0]yc *= self.img_size[1]w *= self.img_size[0]h *= self.img_size[1]boxes.append([xc, yc, w, h, angle])labels.append(cls)boxes = torch.tensor(boxes, dtype=torch.float32) # (N, 5)labels = torch.tensor(labels, dtype=torch.long) # (N,)return img, boxes, labels# -------------------------- 2. 模型结构 --------------------------
class FPN(nn.Module):"""特征金字塔网络(FPN):融合多尺度特征"""def __init__(self, in_channels_list, out_channels):super(FPN, self).__init__()self.lateral_convs = nn.ModuleList() # 横向卷积(降维到out_channels)self.fpn_convs = nn.ModuleList() # 输出卷积(消除 aliasing effect)for in_channels in in_channels_list:self.lateral_convs.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))self.fpn_convs.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))def forward(self, x):""":param x: 从backbone输出的多尺度特征 [C1, C2, C3, C4](分辨率从高到低):return: 融合后的特征 [P1, P2, P3, P4](同输入尺度)"""# 横向连接 + 上采样laterals = [lateral_conv(xi) for lateral_conv, xi in zip(self.lateral_convs, x)]# 从最高层开始融合outs = [laterals[-1]]for i in range(len(laterals)-2, -1, -1):# 上采样高层特征并与当前层融合upsample = F.interpolate(outs[-1], size=laterals[i].shape[2:], mode='bilinear', align_corners=True)outs.append(laterals[i] + upsample)# 反转顺序(从低层到高层)outs = outs[::-1]# 输出卷积outs = [fpn_conv(out) for fpn_conv, out in zip(self.fpn_convs, outs)]return outsclass RotatedDetectionHead(nn.Module):"""旋转框检测头:预测类别和旋转框参数(x, y, w, h, θ)"""def __init__(self, in_channels, num_classes, num_anchors=9):super(RotatedDetectionHead, self).__init__()# 分类头(每个锚框对应num_classes个类别)self.cls_head = nn.Sequential(nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, padding=1))# 回归头(每个锚框对应5个参数:x, y, w, h, θ)self.reg_head = nn.Sequential(nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(in_channels, num_anchors * 5, kernel_size=3, padding=1))self.num_classes = num_classesself.num_anchors = num_anchorsdef forward(self, x):""":param x: FPN输出的多尺度特征 [P1, P2, P3, P4]:return: 分类预测和回归预测(按特征尺度拼接)"""cls_preds = []reg_preds = []for feat in x:cls = self.cls_head(feat) # (B, num_anchors*num_classes, H, W)reg = self.reg_head(feat) # (B, num_anchors*5, H, W)# 维度调整:(B, H*W*num_anchors, num_classes) 和 (B, H*W*num_anchors, 5)cls = cls.permute(0, 2, 3, 1).contiguous().view(cls.shape[0], -1, self.num_classes)reg = reg.permute(0, 2, 3, 1).contiguous().view(reg.shape[0], -1, 5)cls_preds.append(cls)reg_preds.append(reg)return torch.cat(cls_preds, dim=1), torch.cat(reg_preds, dim=1)class RemoteSensingDetector(nn.Module):def __init__(self, num_classes):super(RemoteSensingDetector, self).__init__()# Backbone:ResNet50(取前4个stage的输出作为FPN输入)self.backbone = models.resnet50(pretrained=True)self.backbone_features = nn.ModuleList([self.backbone.conv1, self.backbone.bn1, self.backbone.relu, # C1 (1/2)self.backbone.maxpool, self.backbone.layer1, # C2 (1/4)self.backbone.layer2, # C3 (1/8)self.backbone.layer3 # C4 (1/16)])# FPN输入通道(ResNet50各stage输出通道)self.fpn = FPN(in_channels_list=[256, 512, 1024, 2048], out_channels=256)# 检测头self.detection_head = RotatedDetectionHead(in_channels=256, num_classes=num_classes)def forward(self, x):# Backbone特征提取feats = []for layer in self.backbone_features:x = layer(x)if isinstance(layer, nn.Sequential): # 取layer1~layer3的输出feats.append(x)# FPN融合fpn_feats = self.fpn(feats)# 检测头预测cls_pred, reg_pred = self.detection_head(fpn_feats)return cls_pred, reg_pred# -------------------------- 3. 损失函数 --------------------------
class RotatedLoss(nn.Module):def __init__(self, cls_weight=1.0, reg_weight=5.0):super(RotatedLoss, self).__init__()self.cls_weight = cls_weightself.reg_weight = reg_weightdef forward(self, cls_pred, reg_pred, labels, boxes, anchors):""":param cls_pred: 分类预测 (B, N_anchors, num_classes):param reg_pred: 回归预测 (B, N_anchors, 5):param labels: 真实类别 (B, N_boxes):param boxes: 真实旋转框 (B, N_boxes, 5):param anchors: 锚框 (N_anchors, 5):return: 总损失"""# 简化版:假设已通过IOU匹配锚框与真实框,这里直接计算正样本损失# 实际中需要先进行锚框匹配(如MaxIOU匹配)pos_mask = ... # 正样本掩码(简化,实际需实现)num_pos = pos_mask.sum()# 分类损失(仅正样本)cls_loss = F.cross_entropy(cls_pred[pos_mask], labels.repeat_interleave(num_pos//labels.shape[0]) # 简化,实际需对应标签)# 回归损失:Smooth L1(坐标+宽高) + 角度周期性损失reg_target = self.anchor2target(boxes, anchors[pos_mask]) # 计算锚框到真实框的偏移reg_loss = F.smooth_l1_loss(reg_pred[pos_mask, :4], reg_target[:, :4]) # 坐标+宽高损失# 角度损失(考虑周期性:angle ∈ [-π/2, π/2],使用sin/cos转换)angle_pred = reg_pred[pos_mask, 4]angle_target = reg_target[:, 4]angle_loss = 1 - torch.mean(torch.cos(angle_pred - angle_target) # 余弦损失(角度差越小,损失越小))total_loss = self.cls_weight * cls_loss + self.reg_weight * (reg_loss + angle_loss)return total_lossdef anchor2target(self, boxes, anchors):"""将真实框转换为相对于锚框的偏移量(简化版)"""# 实际中需根据锚框计算dx, dy, dw, dh, dθreturn boxes - anchors # 简化,实际需更复杂的转换# -------------------------- 4. 训练与推理示例 --------------------------
if __name__ == "__main__":# 配置num_classes = 5 # 假设5类目标(如建筑、车辆、机场等)img_dir = "path/to/remote_sensing/images" # 图像路径ann_dir = "path/to/remote_sensing/annotations" # 标注路径batch_size = 2epochs = 10lr = 1e-4# 数据集与加载器dataset = RemoteSensingDataset(img_dir, ann_dir)dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)# 模型、损失函数、优化器model = RemoteSensingDetector(num_classes=num_classes)criterion = RotatedLoss()optimizer = torch.optim.Adam(model.parameters(), lr=lr)# 简化训练循环model.train()for epoch in range(epochs):total_loss = 0.0for imgs, boxes, labels in dataloader:optimizer.zero_grad()cls_pred, reg_pred = model(imgs)# 生成锚框(简化版:假设已实现锚框生成函数)anchors = torch.randn(1000, 5) # 示例锚框,实际需根据特征图生成loss = criterion(cls_pred, reg_pred, labels, boxes, anchors)loss.backward()optimizer.step()total_loss += loss.item()print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(dataloader):.4f}")# 推理示例model.eval()with torch.no_grad():img, _, _ = dataset[0]img = img.unsqueeze(0) # 加batch维度cls_pred, reg_pred = model(img)print("预测类别概率:", cls_pred.softmax(dim=-1)[0, :5]) # 前5个锚框的类别概率print("预测旋转框参数:", reg_pred[0, :5]) # 前5个锚框的旋转框参数
核心实现步骤
- 数据集定义:处理遥感图像和旋转框标注(格式:[x_center, y_center, width, height, angle, class]);
- 模型结构:Backbone(ResNet50)+ Neck(FPN)+ Head(分类头+旋转框回归头);
- 损失函数:分类损失(交叉熵)+ 旋转框回归损失(Smooth L1 + 角度周期性损失);
- 训练与推理:简化的训练循环和推理逻辑。
代码说明
- 数据集:假设标注文件为txt格式,每行包含旋转框参数(中心坐标、宽高、角度)和类别,通过
RemoteSensingDataset
类加载并预处理。 - 模型:
- Backbone使用ResNet50提取多尺度特征;
- FPN融合不同分辨率特征,缓解尺度差异问题;
- 检测头预测目标类别和旋转框参数(支持任意方向)。
- 损失函数:分类损失用交叉熵,回归损失结合Smooth L1(坐标/宽高)和余弦损失(角度周期性)。
进一步优化方向
- 实现完整的锚框匹配机制(如MaxIOU)和非极大值抑制(NMS) 处理旋转框;
- 加入注意力机制(如SE模块)增强目标特征;
- 使用数据增强(如随机旋转、缩放、噪声添加)提升模型鲁棒性;
- 迁移预训练权重(如从COCO数据集迁移)加速收敛。