基于U-NET遥感影像语义分割任务快速上手
目录
核心步骤概览
第一步:准备标注数据 (这是最耗时但最关键的一步)
第二步:搭建数据集和数据加载器
第三步:构建 U-Net 模型
第四步:编写训练脚本 (train.py)
第五步:训练模型
第六步:对新影像进行预测 (predict.py)
核心步骤概览
- 准备标注数据 (最关键!)
- 搭建数据集和数据加载器
- 构建 U-Net 模型
- 编写训练脚本
- 训练模型
- 对新影像进行预测
第一步:准备标注数据 (这是最耗时但最关键的一步)
只有影像,但没有标注,模型是无法学习的。需要为这些 TIF 影像创建对应的像素级标注掩码 (Mask)。
- 标注工具:
- QGIS: 免费、强大的开源地理信息系统软件。可以加载 TIF 影像,然后创建新的矢量图层(多边形),手动绘制建筑物、水系、工程车辆的边界。绘制完成后,需要将矢量图层栅格化 (Rasterize) 成与原始影像分辨率、范围完全一致的 GeoTIFF 或 PNG 文件。这是最专业但也最耗时的方法。
- Labelbox / Supervisely / VGG Image Annotator (VIA): 这些是在线或桌面的图像标注平台,专门为机器学习设计。它们通常提供多边形、多边形套索等工具,操作比 QGIS 更直观。标注完成后,导出为 PNG 格式的掩码文件(每个像素值代表一个类别)。
- ArcGIS: 商业软件,功能强大,类似 QGIS。
- 标注规范 (非常重要!):
- 定义类别: 明确类别 ID。
- 0: 背景 (Background)
- 1: 建筑物 (Building)
- 2: 水系 (Water)
- 3: 工程车辆 (Construction Vehicle)
- 掩码格式: 推荐使用 单通道 PNG 文件。文件名应与原始 TIF 影像对应(如 image_001.tif 对应 image_001_mask.png)。
- 精度: 尽量精确地描绘边界,尤其是建筑物的直角和水系的蜿蜒轮廓。
- 工程车辆: 由于目标小,标注时要特别仔细,确保不遗漏。
- 定义类别: 明确类别 ID。
- 数据集划分:
将标注好的数据划分为:
-
- 训练集 (Training Set): ~70-80% 的数据,用于训练模型。
- 验证集 (Validation Set): ~10-15% 的数据,用于在训练过程中监控模型性能,防止过拟合。
- 测试集 (Test Set): ~10-15% 的数据,用于最终评估模型性能,在整个训练过程中绝对不能使用。
第二步:搭建数据集和数据加载器
创建一个 Python 脚本 dataset.py。
import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2class RemoteSensingDataset(Dataset):def __init__(self, image_dir, mask_dir, transform=None):"""Args:image_dir (str): 存放原始遥感影像 (.tif) 的目录路径。mask_dir (str): 存放标注掩码 (.png) 的目录路径。transform (callable, optional): 数据增强和预处理的转换函数。"""self.image_dir = image_dirself.mask_dir = mask_dirself.transform = transform# 获取所有影像文件名 (假设 .tif 和 .png 同名)self.images = [f for f in os.listdir(image_dir) if f.endswith(('.tif', '.tiff'))]def __len__(self):return len(self.images)def __getitem__(self, idx):# 获取文件名img_name = self.images[idx]mask_name = img_name.replace('.tif', '.png').replace('.tiff', '.png') # 假设掩码是png# 构建完整路径img_path = os.path.join(self.image_dir, img_name)mask_path = os.path.join(self.mask_dir, mask_name)# 加载影像 (PIL Image)# 注意:TIF 可能有多个波段,这里假设是标准的 RGB 3波段image = Image.open(img_path).convert("RGB") # 转为 RGB 模式image = np.array(image) # 转为 numpy array (H, W, C)# 加载掩码 (PIL Image)mask = Image.open(mask_path)mask = np.array(mask) # 转为 numpy array (H, W) 单通道# 确保 mask 的值是 0, 1, 2, 3 (你的类别ID)# 应用数据增强和预处理if self.transform is not None:# Albumentations 的 transform 接受字典 {'image': image, 'mask': mask}transformed = self.transform(image=image, mask=mask)image = transformed['image']mask = transformed['mask']else:# 如果没有 transform,手动进行基本处理image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0mask = torch.from_numpy(mask).long()return image, mask# --- 定义数据增强和预处理 ---
# 强烈建议使用数据增强来提高模型鲁棒性
def get_transforms(train=True):if train:return A.Compose([A.Resize(512, 512), # U-Net 通常需要固定尺寸输入,或使用能处理任意尺寸的变体A.HorizontalFlip(p=0.5),A.VerticalFlip(p=0.5),A.RandomRotate90(p=0.5),A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.5),# A.GaussNoise(var_limit=(10.0, 50.0), p=0.2), # 可选,模拟噪声A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # ImageNet 标准化,常用ToTensorV2(), # 将 numpy array 转为 torch tensor, 并归一化到 [0,1]])else:return A.Compose([A.Resize(512, 512),A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),ToTensorV2(),])# --- 创建数据加载器 ---
# 假设你的数据目录结构如下:
# data/
# ├── train/
# │ ├── images/
# │ └── masks/
# ├── val/
# │ ├── images/
# │ └── masks/
# └── test/
# ├── images/
# └── masks/train_dataset = RemoteSensingDataset(image_dir="data/train/images",mask_dir="data/train/masks",transform=get_transforms(train=True)
)val_dataset = RemoteSensingDataset(image_dir="data/val/images",mask_dir="data/val/masks",transform=get_transforms(train=False) # 验证集通常不做强增强
)# DataLoader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4)
安装依赖:
pip install pillow numpy albumentations
第三步:构建 U-Net 模型
创建 model.py。我们可以自己实现一个简单的 U-Net,但更推荐使用 segmentation_models_pytorch。
安装依赖
pip install segmentation-models-pytorch
生成model.py文件
# model.py
import torch
import torch.nn as nn
import segmentation_models_pytorch as smpdef create_unet_model(num_classes=4, encoder_name='resnet34', encoder_weights='imagenet'):"""使用 smp 库创建 U-Net 模型。Args:num_classes: 分类数 (4: 背景, 建筑物, 水系, 工程车辆)encoder_name: 骨干网络名称,如 'resnet34', 'resnet50', 'efficientnet-b0' 等。encoder_weights: 预训练权重,'imagenet' 表示使用在 ImageNet 上预训练的权重。Returns:PyTorch 模型"""model = smp.Unet(encoder_name=encoder_name,encoder_weights=encoder_weights,in_channels=3, # 输入是 RGB 3波段classes=num_classes,activation=None # 让损失函数 (如 CrossEntropyLoss) 处理)return model# --- 创建模型 ---
model = create_unet_model(num_classes=4, encoder_name='resnet34', encoder_weights='imagenet')
第四步:编写训练脚本 (train.py)
# train.py
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os
from dataset import RemoteSensingDataset, get_transforms, train_loader, val_loader # 假设 dataset.py 已定义
from model import create_unet_model # 或 UNet# --- 1. 设备 ---
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")# --- 2. 超参数 ---
num_classes = 4
lr = 1e-4
batch_size = 8 # 根据你的显存调整,MPS 可能支持 8-16
num_epochs = 100
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)# --- 3. 模型、损失、优化器 ---
model = create_unet_model(num_classes=num_classes).to(device)
# model = UNet(n_classes=num_classes).to(device) # 如果使用手动实现criterion = nn.CrossEntropyLoss() # 适用于多类别分割
optimizer = Adam(model.parameters(), lr=lr)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, verbose=True) # 学习率调度# --- 4. 训练循环 ---
best_val_loss = float('inf')for epoch in range(num_epochs):# --- 训练阶段 ---model.train()train_loss = 0.0for images, masks in train_loader:images = images.to(device)masks = masks.to(device) # [B, H, W]optimizer.zero_grad()outputs = model(images) # [B, num_classes, H, W]loss = criterion(outputs, masks)loss.backward()optimizer.step()train_loss += loss.item()train_loss /= len(train_loader)# --- 验证阶段 ---model.eval()val_loss = 0.0with torch.no_grad():for images, masks in val_loader:images = images.to(device)masks = masks.to(device)outputs = model(images)loss = criterion(outputs, masks)val_loss += loss.item()val_loss /= len(val_loader)print(f"Epoch [{epoch+1}/{num_epochs}] "f"Train Loss: {train_loss:.4f} "f"Val Loss: {val_loss:.4f}")# --- 学习率调度 ---scheduler.step(val_loss)# --- 保存最佳模型 ---if val_loss < best_val_loss:best_val_loss = val_losstorch.save(model.state_dict(), os.path.join(checkpoint_dir, "best_model.pth"))print(f" --> Best model saved at epoch {epoch+1}")# 保存每个 epoch 的模型 (可选)# torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"model_epoch_{epoch+1}.pth"))print("Training completed! Best model saved as 'best_model.pth'")
第五步:训练模型
- 组织好数据目录,确保 data/train/images, data/train/masks 等路径正确。
- 运行训练脚本:
python train.py
- 监控训练过程,观察训练损失和验证损失是否下降。如果验证损失不再下降甚至上升,说明可能过拟合。
第六步:对新影像进行预测 (predict.py)
# predict.py
import torch
import numpy as np
from PIL import Image
import os
from model import create_unet_model
from albumentations import Compose, Resize, Normalize, ToTensorV2
from albumentations.pytorch import ToTensorV2def load_model(model_path, num_classes=4, device='cpu'):model = create_unet_model(num_classes=num_classes)model.load_state_dict(torch.load(model_path, map_location=device))model.to(device)model.eval()return modeldef preprocess_image(image_path, transform):image = Image.open(image_path).convert("RGB")image = np.array(image)# 应用与训练时相同的 transform (Resize, Normalize, ToTensor)transformed = transform(image=image)image = transformed['image'].unsqueeze(0) # 添加 batch 维度 [1, C, H, W]return imagedef postprocess_mask(mask_tensor):# mask_tensor shape: [1, num_classes, H, W]mask = mask_tensor.argmax(dim=1).squeeze(0) # 取最大概率的类别,[H, W]return mask.cpu().numpy().astype(np.uint8)def save_prediction(mask, output_path):# 将预测结果保存为 PNGresult = Image.fromarray(mask)result.save(output_path)# --- 预测流程 ---
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = load_model("checkpoints/best_model.pth", num_classes=4, device=device)# 定义预处理 transform (与训练时验证集相同)
transform = Compose([Resize(512, 512),Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),ToTensorV2(),
])# 预测单张影像
image_path = "path/to/your/new_image.tif"
output_path = "predictions/prediction.png"image_tensor = preprocess_image(image_path, transform).to(device)with torch.no_grad():output = model(image_tensor) # [1, 4, H, W]predicted_mask = postprocess_mask(output)save_prediction(predicted_mask, output_path)
print(f"Prediction saved to {output_path}")