当前位置: 首页 > news >正文

UNet网络 图像分割模型学习

UNet 由Ronneberger等人于2015年提出,专门针对医学图像分割任务,解决了早期卷积网络在小样本数据下的效率问题和细节丢失难题。

一 核心创新

1.1对称编码器-解码器结构

实现上下文信息高分辨率细节的双向融合

如图所示:编码器进行了4步(红框)到达了瓶颈层(紫框),每一步包含两次3x3卷积+ReLU并通过通过2x2最大池化下采样,到达瓶颈层后,解码器也进行了4步(绿框),使用了转置卷积上采样后与编码器对应层特征拼接(跳跃连接(灰色箭头))后再进行两次卷积。

可以看出解码器和编码器非常的对称,呈现一个U型,所以叫UNet。

其中:
编码器:通过池化逐渐扩大感受野。

解码器:逐步恢复空间分辨率,精确定位目标边界。

跳跃连接:将编码器特征与解码器特征拼接,融合多级信息解决深层网络定位精度下降的问题

1.2跳跃连接(Skip Connections)

解决深层卷积神经网络中空间信息丢失细节模糊的核心问题。

因为编码器下采样会丢失细节,而解码器上采样又难以完全恢复位置信息,所以使用跳跃链接来补偿细节。

1.2.1数学形式表达

设编码器第 $l$ 层输出为 $E_l \in \mathbb{R}^{H_l \times W_l \times C_l}$ , 解码器第 $l$ 层输入为 $D_l \in \mathbb{R}^{H_l \times W_l \times C_{l}'}$, 则跳跃连接操作:

$ D_l' = \text{Concat}(E_l, \text{UpSample}(D_{l+1})) $

Concat : 沿通道维度拼接(Channel-wise Concatenation)

UpSample:  转置卷积/双线性插值将解码器输出的分辨率提升至与编码器相同

1.2.2特征融合方法

编码器每层的输出须与解码器对应层上采样后的尺寸匹配,拼接后总通道数为两者之和。

(黑色圆圈)

# PyTorch代码示例:拼接编码器和解码器特征
def forward(self, decoder_feat, encoder_feat):# decoder_feat: [B, C1, H, W] # encoder_feat: [B, C2, H, W]merged = torch.cat([decoder_feat, encoder_feat], dim=1)  # 沿通道拼接return merged  # 结果维度:[B, C1+C2, H, W]

 1.3端到端精细分割(End-to-End Fine Segmentation)

在少量标注数据下仍能输出像素级预测

直接从原始输入图像生成像素级预测的模型设计范式,无需手动设计特征提取器或多阶段后处理。

1.3.1核心

全流程自动映射:输入 → 特征学习 → 高精度分割结果,中间过程由网络自动优化

细节敏感机制:通过多层次特征融合、边界增强模块等手段保证细粒度分割

无后处理输出:输出可直接使用,无需形态学后处理

1.3.2技术实现

编码器:通过卷积与池化逐层提取高层语义(形状、位置)

# 编码器层示例:每次下采样通道数翻倍
class Encoder(nn.Module):def __init__(self, in_ch, out_ch):super().__init__()self.block = nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1),#卷积nn.BatchNorm2d(out_ch),#标准化(归一+线性变换)nn.ReLU(),#非线性激活nn.Conv2d(out_ch, out_ch, 3, padding=1),nn.BatchNorm2d(out_ch),nn.ReLU(),nn.MaxPool2d(2)#最大值池化)def forward(self, x):return self.block(x)

解码器:上采样恢复分辨率 + 跳跃连接补充细节

# 解码器层示例:特征拼接后卷积
class Decoder(nn.Module):def __init__(self, in_ch, out_ch):super().__init__()self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)self.conv = nn.Sequential(nn.Conv2d(out_ch*2, out_ch, 3, padding=1), # 拼接后通道数翻倍nn.BatchNorm2d(out_ch),nn.ReLU())def forward(self, x, skip):x = self.up(x)x = torch.cat([x, skip], dim=1)  # 与编码器特征拼接return self.conv(x)

改良1: 注意力引导跳跃连接:通过空间注意力强化边缘区域(在跳跃连接前应用空间注意力,突出边缘信息)

class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super().__init__()self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2)self.sigmoid = nn.Sigmoid()def forward(self, x):avg = torch.mean(x, dim=1, keepdim=True)max_pool, _ = torch.max(x, dim=1, keepdim=True)concat = torch.cat([avg, max_pool], dim=1)  # 沿通道维度拼接均值和最大值mask = self.sigmoid(self.conv(concat))      # 生成空间注意力掩码return x * mask                             # 加权增强关键区域

改良2: 多尺度损失监督:在不同解码层注入辅助损失。

class MultiScaleLoss(nn.Module):def __init__(self, losses):super().__init__()self.losses = losses  # 各层对应的损失函数列表def forward(self, preds, target):total_loss = 0for pred, loss_fn in zip(preds, self.losses):# 将目标下采样至与当前预测同尺寸_, _, H, W = pred.shaperesized_target = F.interpolate(target, size=(H,W), mode='nearest')total_loss += loss_fn(pred, resized_target)return total_loss

适用性扩展:该范式可迁移至其他密集预测任务,如卫星影像分析、自动驾驶场景理解等。

二 与传统分割模型对比

模型优势局限性
FCN全卷积保留空间信息输出分辨率粗糙,跳跃连接简单
SegNet使用池化索引提升精度特征复用效率低
DeepLab空洞卷积扩大感受野小目标分割边缘模糊
UNet对称结构+密集跳跃连接,细节恢复原版对大尺度变化敏感

三 UNet的改良方法 

3.1跨尺度空洞卷积替换编码器的普通卷积层

在底层使用扩张率=1捕捉细节,高层使用d=3或5扩大感受野。

# 原编码器卷积块
self.encoder_conv = nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1),nn.ReLU(),nn.Conv2d(out_ch, out_ch, 3, padding=1),nn.ReLU()
)# 改进:跨尺度空洞卷积模块
self.encoder_conv = CrossScaleDilatedConv(in_ch, out_ch)

3.2融入密集块融合增强跳跃连接的特征传递

在编码器和解码器拼接前加入密集块

class ImprovedSkipConnection(nn.Module):def __init__(self, in_ch):super().__init__()self.dense_block = DenseBlock(num_layers=4, in_channels=in_ch)def forward(self, enc_feat, dec_feat):enc_processed = self.dense_block(enc_feat)  # 特征增强merged = torch.cat([enc_processed, dec_feat], dim=1)return merged# 在UNet解码器中应用
def forward(self, x):# ... 编码过程d4 = self.upconv4(d5)d4 = self.skip_conn4(e4, d4)  # 使用改进的跳跃连接d4 = self.decoder_conv4(d4)# ...

四 核心代码(未改良)

class UNet(nn.Module):def __init__(self, n_class=1):super().__init__()# 编码器self.enc1 = EncoderBlock(3, 64)self.enc2 = EncoderBlock(64, 128)self.enc3 = EncoderBlock(128, 256)self.enc4 = EncoderBlock(256, 512)self.bottleneck = EncoderBlock(512, 1024)# 解码器self.upconv4 = UpConv(1024, 512)self.dec4 = DecoderBlock(1024, 512)  # 输入1024因拼接self.upconv3 = UpConv(512, 256)self.dec3 = DecoderBlock(512, 256)self.upconv2 = UpConv(256, 128)self.dec2 = DecoderBlock(256, 128)self.upconv1 = UpConv(128, 64)self.dec1 = DecoderBlock(128, 64)self.final = nn.Conv2d(64, n_class, kernel_size=1)def forward(self, x):# 编码e1 = self.enc1(x)e2 = self.enc2(F.max_pool2d(e1, 2))e3 = self.enc3(F.max_pool2d(e2, 2))e4 = self.enc4(F.max_pool2d(e3, 2))bn = self.bottleneck(F.max_pool2d(e4, 2))# 解码d4 = self.dec4(self.upconv4(bn), e4)d3 = self.dec3(self.upconv3(d4), e3)d2 = self.dec2(self.upconv2(d3), e2)d1 = self.dec1(self.upconv1(d2), e1)return torch.sigmoid(self.final(d1))class EncoderBlock(nn.Module):def __init__(self, in_ch, out_ch):super().__init__()self.conv = nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1),nn.BatchNorm2d(out_ch),nn.ReLU(),nn.Conv2d(out_ch, out_ch, 3, padding=1),nn.BatchNorm2d(out_ch),nn.ReLU())def forward(self, x):return self.conv(x)class UpConv(nn.Module):def __init__(self, in_ch, out_ch):super().__init__()self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)def forward(self, x):return self.up(x)class DecoderBlock(nn.Module):def __init__(self, in_ch, out_ch):super().__init__()self.conv = EncoderBlock(in_ch, out_ch)def forward(self, x, skip):x = torch.cat([x, skip], dim=1)  # 通道拼接return self.conv(x)

UNet凭借其优雅的对称结构和密集跳跃连接,成为医学图像分割的基准模型。通过集成跨尺度空洞卷积密集块融合等模块,可显著提升其对多尺度目标的适应性。

http://www.xdnf.cn/news/395137.html

相关文章:

  • 用户线程和守护线程
  • 机器学习极简入门:从基础概念到行业应用
  • 视频编码原理讲解一:VCL层和NAL层的讲解
  • 微服务架构-注册中心、配置中心:nacos入门
  • IPLOOK超轻量核心网,助力5G专网和MEC边缘快速落地
  • macOS 15 (Sequoia) 解除Gatekeeper限制
  • 可变参数模板
  • 微服务架构-限流、熔断
  • 小智AI机器人 - 代码框架梳理2
  • 【GPT入门】第38课 RAG评估指标概述
  • 什么是深度神经网络
  • AI自动化测试工具有哪些?
  • 优秀的流程图设计软件【留存】
  • stm32实战项目:无刷驱动
  • 深入浅出之STL源码分析7_模版实例化与全特化
  • 封装和分用(网络原理)
  • C# 方法(方法重载)
  • 查看YOLO版本的三种方法
  • 关于解决MySQL的常见问题
  • Linux基础开发工具一(yum/apt ,vim)
  • 滑动窗口——将x减到0的最小操作数
  • Python中的标识、相等性与别名:深入理解对象引用机制
  • Gartner 《2025大数据管理规划指南》学习心得
  • 【安装配置教程】ubuntu安装配置Kodbox
  • 【RP2350】香瓜树莓派RP2350之搭建开发环境(windows)
  • AI日报 - 2024年05月12日
  • redis数据结构-05 (LPUSH、RPUSH、LPOP、RPOP)
  • 第二十二节:图像金字塔-拉普拉斯金字塔
  • 深入浅出:Spring Boot 中 RestTemplate 的完整使用指南
  • AI Agent(9):企业应用场景