当前位置: 首页 > backend >正文

CNN卷积神经网络之注意力机制CBAM(六)

CNN卷积神经网络之注意力机制CBAM(六)

文章目录

  • CNN卷积神经网络之注意力机制CBAM(六)
    • 1. CBAM 基本认知与结构图
    • 2. 通道注意力模块(Channel Attention)
      • 2.1 原理与流程图
      • 2.2 通道注意力 PyTorch 代码
    • 3. 空间注意力模块(Spatial Attention)
      • 3.1 原理与流程图
      • 3.2 空间注意力 PyTorch 代码
    • 4. 消融实验图表
      • 4.1 通道注意力有效性
      • 4.2 空间注意力叠加效果
      • 4.3 注意力顺序对比
      • 4.4 不同模型横向对比
    • 5. ResNet50 整合 CBAM 完整代码
    • 小结


下面为你完整提取文档中关于 CBAM(Convolutional Block Attention Module) 的全部内容,包括:

  • 核心思想与结构图
  • 通道注意力(Channel Attention)详解
  • 空间注意力(Spatial Attention)详解
  • 消融实验图表
  • 完整可运行的 PyTorch 代码(通道注意力、空间注意力、整合到 ResNet50)

1. CBAM 基本认知与结构图

图 1:CBAM 整体结构(通道注意力 + 空间注意力)
在这里插入图片描述

CBAM 是一个轻量级混合注意力模块,先后沿 通道维度空间维度 计算注意力权重,公式化描述如下:

F′=Mc(F)⊗F,F′′=Ms(F′)⊗F′.\begin{aligned} \mathbf{F'} &= \mathbf{M_c}(\mathbf{F}) \otimes \mathbf{F}, \\ \mathbf{F''} &= \mathbf{M_s}(\mathbf{F'}) \otimes \mathbf{F'}. \end{aligned} FF′′=Mc(F)F,=Ms(F)F.

  • Mc∈RC×1×1\mathbf{M_c} \in \mathbb{R}^{C\times 1\times 1}McRC×1×1:通道注意力图
  • Ms∈R1×H×W\mathbf{M_s} \in \mathbb{R}^{1\times H\times W}MsR1×H×W:空间注意力图
  • ⊗\otimes:逐元素乘法(广播机制)

2. 通道注意力模块(Channel Attention)

2.1 原理与流程图

图 2:通道注意力流程
在这里插入图片描述
  1. 对输入特征图 F∈RC×H×W\mathbf{F}\in\mathbb{R}^{C\times H\times W}FRC×H×W 分别做 全局平均池化全局最大池化,得到两个 1×1×C1\times 1\times C1×1×C 的向量。
  2. 将两个向量分别送入 共享的 MLP(两层全连接:先降维 C/rC/rC/r,再升维回 CCC)。
  3. 将 MLP 输出的两个特征向量逐元素相加,经 Sigmoid 得到通道注意力图 $mathbfMcmathbf{M_c}mathbfMc

数学表达:

Mc(F)=σ(MLP(AvgPool(F))+MLP(MaxPool(F)))\mathbf{M_c}(\mathbf{F})=\sigma\left(\text{MLP}\bigl(\text{AvgPool}(\mathbf{F})\bigr)+\text{MLP}\bigl(\text{MaxPool}(\mathbf{F})\bigr)\right) Mc(F)=σ(MLP(AvgPool(F))+MLP(MaxPool(F)))

2.2 通道注意力 PyTorch 代码

import torch
import torch.nn as nnclass ChannelAttention(nn.Module):def __init__(self, in_planes, ratio=16):super(ChannelAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),nn.ReLU(inplace=True),nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False))self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = self.fc(self.avg_pool(x))max_out = self.fc(self.max_pool(x))out = avg_out + max_outreturn self.sigmoid(out)

3. 空间注意力模块(Spatial Attention)

3.1 原理与流程图

图 3:空间注意力流程
在这里插入图片描述
  1. 通道维度 分别做 最大池化平均池化,得到两个 H×W×1H\times W\times 1H×W×1 的二维图。
  2. 将两个二维图在通道维度 拼接,形成 H×W×2H\times W\times 2H×W×2 的特征。
  3. 使用一个 7×77\times 77×7 卷积层将 222 通道压缩为 111 通道,再经 Sigmoid 得到空间注意力图 Ms\mathbf{M_s}Ms

数学表达:

Ms(F)=σ(f7×7([AvgPool(F);MaxPool(F)]))\mathbf{M_s}(\mathbf{F})=\sigma\left(f^{7\times7}\bigl([\text{AvgPool}(\mathbf{F});\text{MaxPool}(\mathbf{F})]\bigr)\right) Ms(F)=σ(f7×7([AvgPool(F);MaxPool(F)]))

3.2 空间注意力 PyTorch 代码

class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super(SpatialAttention, self).__init__()assert kernel_size in (3, 7), "kernel size must be 3 or 7"padding = kernel_size // 2self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)x = torch.cat([avg_out, max_out], dim=1)x = self.conv(x)return self.sigmoid(x)

4. 消融实验图表

4.1 通道注意力有效性

图 4:仅通道注意力对比
在这里插入图片描述
  • 加入通道注意力后,所有模型均优于 baseline。

4.2 空间注意力叠加效果

图 5:叠加空间注意力
在这里插入图片描述
  • CBAM(通道+空间) 优于单独通道注意力;MaxPool + AvgPool 组合 最佳。

4.3 注意力顺序对比

图 6:通道先 / 空间先 对比
在这里插入图片描述
  • 通道 → 空间 的顺序(即 CBAM 默认顺序)效果最好。

4.4 不同模型横向对比

图 7:CBAM 在不同模型上的增益
在这里插入图片描述
  • 在 ResNet、WideResNet、ResNeXt 等主流网络上均带来一致提升。

5. ResNet50 整合 CBAM 完整代码

下面给出 ResNet50 + CBAM 的完整实现,可直接运行并导出 ONNX:

import torch
import torch.nn as nn
from torchvision.models.resnet import ResNet, Bottleneck, _resnet, ResNet50_Weights# -------------------- 1. 通道注意力 --------------------
class ChannelAttention(nn.Module):def __init__(self, in_planes, ratio=16):super(ChannelAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),nn.ReLU(inplace=True),nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False))self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = self.fc(self.avg_pool(x))max_out = self.fc(self.max_pool(x))out = avg_out + max_outreturn self.sigmoid(out)# -------------------- 2. 空间注意力 --------------------
class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super(SpatialAttention, self).__init__()assert kernel_size in (3, 7), "kernel_size must be 3 or 7"padding = kernel_size // 2self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)x = torch.cat([avg_out, max_out], dim=1)x = self.conv(x)return self.sigmoid(x)# -------------------- 3. 改造 Bottleneck,插入 CBAM --------------------
class BottleneckCBAM(Bottleneck):expansion = 4def __init__(self, *args, **kwargs):super().__init__(*args, **kwargs)# conv3 的输出通道即为 planes * expansionself.ca = ChannelAttention(self.conv3.out_channels)self.sa = SpatialAttention()def forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)# CBAM 注意力out = self.ca(out) * out        # 通道加权out = self.sa(out) * out        # 空间加权if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return out# -------------------- 4. 构建 ResNet50-CBAM --------------------
def resnet50_cbam(pretrained=False, **kwargs):"""返回带有 CBAM 的 ResNet50。pretrained=True 时加载 ImageNet 预训练权重(注意:CBAM 部分需要微调)。"""if pretrained:weights = ResNet50_Weights.DEFAULTelse:weights = Nonereturn _resnet(BottleneckCBAM,[3, 4, 6, 3],weights,progress=True,**kwargs)# -------------------- 5. 测试脚本 --------------------
if __name__ == "__main__":model = resnet50_cbam(pretrained=False, num_classes=10)  # 10 类示例model.eval()dummy = torch.randn(1, 3, 224, 224)# 打印网络结构(可选)# print(model)# 前向测试with torch.no_grad():out = model(dummy)print("输出 shape:", out.shape)          # 应为 [1, 10]# 导出 ONNXtorch.onnx.export(model, dummy, "resnet50_cbam.onnx",opset_version=11, input_names=["input"],output_names=["output"])print("已导出 resnet50_cbam.onnx")

小结

维度关键要点
模块顺序先通道注意力 → 后空间注意力
通道注意力AvgPool + MaxPool → Shared MLP → Sigmoid
空间注意力Channel-wise Avg & Max → Concat → 7×7 Conv → Sigmoid
即插即用可插入任何 CNN 架构,仅需几行代码
效果在 ImageNet、CIFAR 等多数据集、多模型上稳定提升
http://www.xdnf.cn/news/17013.html

相关文章:

  • 【android bluetooth 协议分析 01】【HCI 层介绍 30】【hci_event和le_meta_event如何上报到btu层】
  • uniapp Android App集成支付宝的扫码组件mPaaS
  • Linux 内存管理之 Rmap 反向映射(二)
  • Kafka-Eagle 安装
  • 江协科技STM32学习笔记1
  • AlexNet训练和测试FashionMNIST数据集
  • 什么是越权漏洞?如何验证。
  • c++介绍
  • cJSON库应用
  • Python高级编程与实践:Python装饰器深入解析与应用
  • 【数据结构初阶】--排序(三):冒泡排序,快速排序
  • BeeWorks私有化即时通讯,局域网办公安全可控
  • Python基础框架
  • 改进的BP神经网络算法用于预测温度值的变化
  • 剑指offer第2版:字符串
  • jenkins插件Active Choices的使用通过参数动态控制多选参数的选项
  • java web 未完成项目,本来想做个超市管理系统,前端技术还没学。前端是个简单的html。后端接口比较完善。
  • mq_timedsend系统调用及示例
  • 朴素贝叶斯(Naive Bayes)算法详解
  • 使用 ECharts 实现小区住户数量统计柱状图
  • 豆包新模型与 PromptPilot 实操体验测评,AI 辅助创作的新范式探索
  • 涨薪技术|Kubernetes(k8s)之Pod生命周期(上)
  • 山东省天地图API申请并加载到QGIS和ArcGIS Pro中
  • pyspark中的kafka的读和写案例操作
  • 面向对象编程基础:类的实例化与对象内存模型详解
  • Oracle 在线重定义
  • 【unitrix】 7.2 二进制位减法(bit_sub.rs)
  • MySQL偏门但基础的面试题集锦
  • MySql的两种安装方式
  • MySQL Router