当前位置: 首页 > java >正文

【AI模型学习】上/下采样

文章目录

  • 分割中的上/下采样
  • 下采样
    • SegFormer和PVT(使用卷积)
    • Swin-Unet(使用 Patch Merging)
  • 上采样
    • SegFormer(interpolate)
    • Swin-Unet(Patch Expanding)
    • 逐级interpolate的方式
    • 反卷的方式


基于Transformer架构的图像分割模型(如 SegFormer、Swin-Unet)中,上采样和下采样结构几乎是标准配置。

分割中的上/下采样

为什么需要下采样?

  1. 提取高层语义特征
    Transformer擅长全局建模,结合下采样可以:降低分辨率;聚焦于更宽范围的上下文。

  2. 减少计算成本
    原始输入图像太大,直接送入多层Transformer(特别是多头注意力)会导致计算量和显存爆炸。

为什么要上采样?

  1. 恢复空间分辨率
    Segmentation任务最终要输出与输入图像同样大小的分割mask;

  2. 细粒度定位
    但如果没有上采样、跳跃连接或融合,容易失去细节;所以上采样常结合UNet-like结构来补偿细节损失。

模型下采样方式上采样方式
SETRViT backbone patchify多层反卷积上采样
SegFormerMLP Mixer + 4阶段卷积下采样多层插值 + FFN
Swin-UnetSwin Transformer 下采样Patch expanding + skip连接

下采样

SegFormer和PVT(使用卷积)

# 输入:img.shape = [B, 3, 512, 512]# Stage 1
x1 = Conv2d(3, 32, kernel_size=7, stride=4, padding=3)(img)    # → [B, 32, 128, 128]
x1 = x1.flatten(2).transpose(1, 2)                             # → [B, 16384, 32]
x1 = TransformerBlock(x1)# Stage 2
x2 = Conv2d(32, 64, kernel_size=3, stride=2, padding=1)(x1_reshaped)  # → [B, 64, 64, 64]
x2 = x2.flatten(2).transpose(1, 2)                                    # → [B, 4096, 64]
x2 = TransformerBlock(x2)# 后面还有 Stage3、Stage4 类似

Shape 演化:

Stage1: [B, 128×128=16384, 32]
Stage2: [B, 64×64=4096, 64]
Stage3: [B, 32×32=1024, 160]
Stage4: [B, 16×16=256, 256]

Swin-Unet(使用 Patch Merging)

# 初始 patch embedding(patch_size=4)
x = Conv2d(3, 96, kernel_size=4, stride=4)(img)      # [B, 96, 128, 128]
x = x.flatten(2).transpose(1, 2)                     # → [B, 16384, 96]# Stage 1
x = SwinBlock(x)                                     # [B, 16384, 96]
x = PatchMerging(x)                                  # → [B, 4096, 192]# Stage 2
x = SwinBlock(x)                                     # [B, 4096, 192]
x = PatchMerging(x)                                  # → [B, 1024, 384]# Stage 3
x = SwinBlock(x)
x = PatchMerging(x)                                  # → [B, 256, 768]

Shape 演化:

Stage0: [B, 128×128=16384, 96]
Stage1: [B, 64×64=4096, 192]
Stage2: [B, 32×32=1024, 384]
Stage3: [B, 16×16=256, 768]

过程不难,只是不好描述,可以看相关教程,这里就把代码贴出来

class PatchMerging(nn.Module):def __init__(self, in_dim):super().__init__()self.reduction = nn.Linear(in_dim * 4, in_dim * 2)def forward(self, x, H, W):# x: [B, H*W, C] → [B, H, W, C]x = x.view(B, H, W, C)# 拆分四个方向的 tokenx0 = x[:, 0::2, 0::2, :]  # top-leftx1 = x[:, 1::2, 0::2, :]  # bottom-leftx2 = x[:, 0::2, 1::2, :]  # top-rightx3 = x[:, 1::2, 1::2, :]  # bottom-rightx = torch.cat([x0, x1, x2, x3], dim=-1)  # → [B, H/2, W/2, 4C]x = x.view(B, -1, 4 * C)                 # → [B, H/2*W/2, 4C]x = self.reduction(x)                   # → [B, H/2*W/2, 2C]return x

上采样

SegFormer(interpolate)

在这里插入图片描述

def forward(self, x1, x2, x3, x4):  # 输入来自4个Stage:# x1: [B, 128*128, 32]# x2: [B, 64*64,   64]# x3: [B, 32*32,  160]# x4: [B, 16*16,  256]B = x1.shape[0]# === 1. Linear Projection:通道都投影为 256 ===_x1 = self.linear1(x1).permute(0, 2, 1).reshape(B, 256, 128, 128)  # [B, 256, 128, 128]_x2 = self.linear2(x2).permute(0, 2, 1).reshape(B, 256,  64,  64)  # [B, 256,  64,  64]_x3 = self.linear3(x3).permute(0, 2, 1).reshape(B, 256,  32,  32)  # [B, 256,  32,  32]_x4 = self.linear4(x4).permute(0, 2, 1).reshape(B, 256,  16,  16)  # [B, 256,  16,  16]# === 2. 上采样到统一大小 ===_x2 = F.interpolate(_x2, size=(128, 128), mode='bilinear', align_corners=False)  # [B, 256, 128, 128]_x3 = F.interpolate(_x3, size=(128, 128), mode='bilinear', align_corners=False)  # [B, 256, 128, 128]_x4 = F.interpolate(_x4, size=(128, 128), mode='bilinear', align_corners=False)  # [B, 256, 128, 128]# === 3. 拼接所有层 ===fused = torch.cat([_x1, _x2, _x3, _x4], dim=1)  # [B, 4*256=1024, 128, 128]# === 4. 1x1卷积融合通道数 ===out = self.fuse_conv(fused)  # [B, 256, 128, 128]return out

Swin-Unet(Patch Expanding)

看图也能看出来,十分经典的U-Net结构。

在这里插入图片描述

在上采样阶段
输入:

一个高语义 token,维度为 [4C],是上一步 Patch Merging 得到的。

  1. Linear 映射
    [4C] 投影为 [C] × 4,也就是还原为 2×2 patch 每格的 C 维向量。

  2. reshape → [H, W, 2, 2, C] → [2H, 2W, C]
    把这 4 个 token 安排到一个新的空间位置(上采样 ×2)。

  3. 最终输出为:

    Token 数量 × 4 , 通道数 ÷ 2 \text{Token 数量} \times 4,\quad \text{通道数} \div 2 Token 数量×4通道数÷2

class PatchExpanding(nn.Module):def __init__(self, in_dim, expand_ratio=2):super().__init__()# Linear: [B, H*W, in_dim] → [B, H*W, out_dim = (expand_ratio^2) * out_channels]# 例如:in_dim = 512,expand_ratio = 2 → 输出 4×C = 1024self.linear = nn.Linear(in_dim, in_dim // 2 * expand_ratio**2)self.expand_ratio = expand_ratiodef forward(self, x, H, W):# x: [B, H*W, C]B, N, C = x.shapeR = self.expand_ratio  # 通常为 2# 线性投影:C → 4 * (C/2),也就是 [B, H*W, 4*C'],每个 token 展开为 2×2 的 patchx = self.linear(x)                      # [B, H*W, 4*C'] = [B, H*W, R*R*(C//2)]#  reshape 成图像形式,带有 2×2 子结构 → [B, H, W, R, R, C']x = x.view(B, H, W, R, R, C // 2)       # [B, H, W, 2, 2, C//2]#  调整顺序,将 2×2 子结构移入空间维度 → [B, H*2, W*2, C//2]x = x.permute(0, 1, 3, 2, 4, 5)         # [B, H, 2, W, 2, C//2]x = x.reshape(B, H * R, W * R, C // 2)  # [B, 2H, 2W, C//2]#  flatten 成 token 序列形式(可再送入 Transformer)→ [B, 4*H*W, C//2]x = x.view(B, -1, C // 2)               # [B, 4*H*W, C//2]return x

逐级interpolate的方式

  1. 输入来自编码器 4 个 stage

    • x4:[16×16, 512] ← 最深层
    • x3:[32×32, 320]
    • x2:[64×64, 128]
    • x1:[128×128, 64] ← 最浅层
  2. 通道统一:
    每个特征图先通过 1×1 卷积或 Linear 映射,统一成相同维度(如全部 → 256 或 512)

  3. 上采样与融合(逐级):

    f4 = Conv(x4)                          # [16×16]
    f3 = F.interpolate(f4, scale=2) + Conv(x3)  # → [32×32]
    f2 = F.interpolate(f3, scale=2) + Conv(x2)  # → [64×64]
    f1 = F.interpolate(f2, scale=2) + Conv(x1)  # → [128×128]
    

反卷的方式

很经典的设计,不必过多介绍。

http://www.xdnf.cn/news/8241.html

相关文章:

  • 【SpringBoot实战指南】使用 Spring Cache
  • 5.22 打卡
  • 生存资料的多因素分析,如果满 足等比例风险假定, 采用Cox回归; 如果不满足等比例风险假定,则考虑采用 非等比例Cox回归分析研究预后因素的影响
  • Java版本的VPN(wlcn)
  • 我的世界模组开发——物理学(1)
  • PiliPlus 非常好用的开源软件第三方B站哔哩哔哩 v1.1.3.35
  • Vue 3.0中异步组件defineAsyncComponent
  • JC/T 2387-2024 改性聚苯乙烯泡沫(EPS)复合装饰制品检测
  • 从零基础到最佳实践:Vue.js 系列(10/10):《实战项目——从零到上线》
  • 2025淘宝最新DSR评分计算方式
  • Python RSA加解密脚本
  • AI相关的笔记
  • (第93天)OGG 搭建 Oracle 19C 数据同步 - 远程部署
  • 博奥龙Nanoantibody系列IP专用抗体
  • ubuntu安装blender并配置应用程序图标
  • HW云RDS性能压测
  • C++中的菱形继承问题
  • 5.22学习日记 ssh远程加密、非对称加密、对称加密与中间人攻击的原理
  • Linux安装SRILM
  • 【Android开发——Activity简述】
  • Femap许可证兼容性问题
  • 同城上门预约服务系统案例分享,上门服务到家系统都有什么功能?这个功能,很重要!
  • 科学养生指南:解锁健康生活密码
  • 一个简易的图片与文件从同一个入口上传
  • 【数据结构】链式二叉树
  • 物理定律的数学结构基础及AI推理
  • [欠拟合过拟合]机器学习-part10
  • Java:希尔排序
  • 15.集合框架的学习
  • Unity基础学习(六)Mono中的重要内容(2)协同程序