
Attention Mechanism Module Code

  • Widely recommended: SE, ECA, and Coordinate Attention (CA). They are lightweight, easy to use, and consistently effective.

  • Still usable, but factor in the computational cost: BAM, GCNet, SKNet.

  • Generally not a first choice (considered "dated" or being phased out): Non-local, DANet; they are particularly hard to apply to large-scale or 3D medical images.

 

SE Module (Squeeze-and-Excitation, channel attention)

import torch
import torch.nn as nn

class SEBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # global average pooling
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)   # [B, C]
        y = self.fc(y).view(b, c, 1, 1)   # channel attention weights
        return x * y.expand_as(x)

Suitable insertion points:

  • As a channel-attention module after convolutional layers, typically at the end of each convolutional or residual block;

  • After the convolutional output of each encoder stage, to recalibrate the channels.

What it does:

  • Generates channel weights through "Squeeze" (global average pooling) and "Excitation" (two fully connected layers);

  • Strengthens the model's response to key channels while suppressing irrelevant ones;

  • Simple structure, few parameters, easy to insert (a usage sketch follows below).
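A minimal placement sketch, assuming the SEBlock class above is in scope; ResidualSEBlock is a hypothetical wrapper used only for illustration, with the SE module applied just before the residual addition:

import torch.nn as nn

# Hypothetical residual block with SEBlock appended at the end (a common placement).
class ResidualSEBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        self.se = SEBlock(channels)      # channel recalibration before the residual add

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.se(out)               # reweight channels
        return self.relu(out + x)        # residual connection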

CBAM Module (Convolutional Block Attention Module, channel + spatial attention)

import torch
import torch.nn as nn

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, reduction=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_planes, in_planes // reduction, bias=False),
            nn.ReLU(),
            nn.Linear(in_planes // reduction, in_planes, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        avg_out = self.fc(self.avg_pool(x).view(b, c))
        max_out = self.fc(self.max_pool(x).view(b, c))
        out = avg_out + max_out
        out = self.sigmoid(out).view(b, c, 1, 1)
        return x * out.expand_as(x)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)     # average pooling over the channel dimension
        max_out, _ = torch.max(x, dim=1, keepdim=True)   # max pooling over the channel dimension
        out = torch.cat([avg_out, max_out], dim=1)       # 2-channel input
        out = self.conv(out)
        out = self.sigmoid(out)
        return x * out

class CBAM(nn.Module):
    def __init__(self, in_planes, reduction=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(in_planes, reduction)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        out = self.channel_attention(x)
        out = self.spatial_attention(out)
        return out
  • Suitable insertion points

    • After convolutional layers, as a feature-enhancement module;

    • Inside each convolutional block of the encoder or decoder (e.g., after each Down or Up block of a UNet);

    • At the bridge stage (between encoder and decoder), to strengthen high-level semantic representation.

  • Why insert it

    • Chains channel attention and spatial attention, reinforcing information along both the "channel" and "spatial position" dimensions;

    • Suppresses redundant background regions and highlights the key channels and spatial positions of hemorrhage regions;

    • Lightweight, clearly effective, and easy to embed in any CNN architecture (a placement sketch follows below).
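A minimal placement sketch, assuming the CBAM class above is in scope; DoubleConvCBAM is a hypothetical UNet-style "double conv" block used only for illustration:

import torch.nn as nn

# Hypothetical UNet-style block: two 3x3 convolutions followed by CBAM.
class DoubleConvCBAM(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
        self.cbam = CBAM(out_ch)   # channel attention, then spatial attention

    def forward(self, x):
        return self.cbam(self.conv(x))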

Non-Local Attention (global spatial self-attention module)

import torch
import torch.nn as nn

class NonLocalBlock(nn.Module):
    def __init__(self, in_channels, inter_channels=None):
        super(NonLocalBlock, self).__init__()
        self.in_channels = in_channels
        self.inter_channels = inter_channels if inter_channels else in_channels // 2
        if self.inter_channels == 0:
            self.inter_channels = 1
        self.g = nn.Conv2d(in_channels, self.inter_channels, kernel_size=1)
        self.theta = nn.Conv2d(in_channels, self.inter_channels, kernel_size=1)
        self.phi = nn.Conv2d(in_channels, self.inter_channels, kernel_size=1)
        self.W = nn.Conv2d(self.inter_channels, in_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(in_channels)

    def forward(self, x):
        batch_size, C, H, W = x.size()
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)          # [B, C', H*W]
        g_x = g_x.permute(0, 2, 1)                                         # [B, H*W, C']
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)  # [B, C', H*W]
        theta_x = theta_x.permute(0, 2, 1)                                 # [B, H*W, C']
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)      # [B, C', H*W]
        f = torch.matmul(theta_x, phi_x)                                   # [B, H*W, H*W]
        f_div_C = nn.functional.softmax(f, dim=-1)
        y = torch.matmul(f_div_C, g_x)                                     # [B, H*W, C']
        y = y.permute(0, 2, 1).contiguous()                                # [B, C', H*W]
        y = y.view(batch_size, self.inter_channels, H, W)
        W_y = self.W(y)
        W_y = self.bn(W_y)
        z = W_y + x  # residual connection
        return z

Suitable insertion points:

  • Mid-level layers where the feature map is moderately sized, for example the middle-to-late feature stages of the encoder;

  • Anywhere long-range dependencies need to be captured.

What it does:

  • Models the relationship between any two spatial positions and computes attention as a global weighted aggregation;

  • Captures long-range contextual dependencies and strengthens feature representation;

  • For brain hemorrhage image segmentation, it helps capture information related to large-extent lesions (see the cost note and usage sketch below).
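A minimal usage sketch with illustrative shapes only. Note from the code above that the attention map f has shape [B, H*W, H*W], so memory grows quadratically with spatial resolution; this is why the block is normally applied only to downsampled feature maps:

import torch

# Illustrative: apply the block to a mid/late encoder feature map.
feat = torch.randn(2, 256, 32, 32)      # [B, C, H, W]
nl = NonLocalBlock(in_channels=256)     # uses the class defined above
out = nl(feat)                          # same shape as the input: [2, 256, 32, 32]
# The internal attention map has shape [B, H*W, H*W] = [2, 1024, 1024];
# at full resolution this becomes prohibitively large.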

Multi-Head Self-Attention Module Commonly Used in Transformers (simplified)

import torch
import torch.nn as nn

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        # x shape: [batch_size, seq_len, embed_dim]
        batch_size, seq_len, embed_dim = x.size()
        qkv = self.qkv_proj(x)                               # [B, S, 3*E]
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)                     # [3, B, heads, seq_len, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]                     # each [B, heads, seq_len, head_dim]
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)  # scaled dot product
        attn_weights = nn.functional.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)          # [B, heads, seq_len, head_dim]
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
        output = self.out_proj(attn_output)
        return output
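This module operates on token sequences of shape [batch_size, seq_len, embed_dim] rather than on feature maps. A minimal sketch (illustrative shapes only) of the usual way to use it inside a CNN: flatten the spatial grid into tokens, attend, then reshape back:

import torch

# Illustrative: treat each spatial location of a [B, C, H, W] feature map as a token.
feat = torch.randn(2, 64, 16, 16)                   # [B, C, H, W]
b, c, h, w = feat.shape
tokens = feat.flatten(2).transpose(1, 2)            # [B, H*W, C]
mhsa = MultiHeadSelfAttention(embed_dim=c, num_heads=8)
tokens = mhsa(tokens)                               # [B, H*W, C]
feat = tokens.transpose(1, 2).reshape(b, c, h, w)   # back to [B, C, H, W]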

BAM Module (Bottleneck Attention Module, channel + spatial)

import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16):
        super(ChannelGate, self).__init__()
        self.mlp = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(gate_channels, gate_channels // reduction_ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(gate_channels // reduction_ratio, gate_channels, 1, bias=False)
        )

    def forward(self, x):
        y = self.mlp(x)            # [B, C, 1, 1] channel gate
        return y

class SpatialGate(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialGate, self).__init__()
        self.spatial = nn.Sequential(
            nn.Conv2d(1, 1, kernel_size, padding=kernel_size // 2, bias=False),
            nn.BatchNorm2d(1)
        )

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)     # channel-wise mean
        max_out, _ = torch.max(x, dim=1, keepdim=True)   # channel-wise max
        y = avg_out + max_out
        y = self.spatial(y)        # [B, 1, H, W] spatial gate
        return y

class BAM(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, kernel_size=7):
        super(BAM, self).__init__()
        self.channel_gate = ChannelGate(gate_channels, reduction_ratio)
        self.spatial_gate = SpatialGate(kernel_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        chn_att = self.channel_gate(x)         # [B, C, 1, 1]
        sp_att = self.spatial_gate(x)          # [B, 1, H, W]
        att = self.sigmoid(chn_att + sp_att)   # broadcast to [B, C, H, W]
        return x * att

Suitable insertion points:

  • After residual blocks in the middle stages of the backbone, for example after ResNet residual blocks;

  • After feature fusion between the UNet encoder and decoder;

  • At skip connections, to jointly weight channel and spatial information.

Why:

  • BAM combines spatial and channel attention, helping the model highlight important spatial regions and key channel features;

  • Applied once deep-layer features have become rich, it better reinforces important information and suppresses what is irrelevant;

  • In brain hemorrhage CT images the lesion is locally salient, and BAM helps localize the lesion region (a placement sketch follows below).
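A minimal placement sketch, assuming the BAM class above is in scope; stage3 and stage4 are hypothetical placeholders for two consecutive backbone stages (e.g. ResNet layer3/layer4):

import torch.nn as nn

# Hypothetical wiring: insert BAM between two backbone stages.
class StagesWithBAM(nn.Module):
    def __init__(self, stage3, stage4, channels):
        super().__init__()
        self.stage3 = stage3
        self.bam = BAM(channels)    # joint channel + spatial gating
        self.stage4 = stage4

    def forward(self, x):
        x = self.stage3(x)
        x = self.bam(x)             # reweight features before the next stage
        return self.stage4(x)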

ECA Module (Efficient Channel Attention, channel)

import torch
import torch.nn as nn
import torch.nn.functional as F

class ECALayer(nn.Module):
    def __init__(self, channel, k_size=3):
        super(ECALayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: [B, C, H, W]
        y = self.avg_pool(x)                                 # [B, C, 1, 1]
        y = y.squeeze(-1).transpose(-1, -2)                  # [B, 1, C]
        y = self.conv(y)                                     # [B, 1, C]
        y = self.sigmoid(y).transpose(-1, -2).unsqueeze(-1)  # [B, C, 1, 1]
        return x * y.expand_as(x)

Suitable insertion points:

  • After the output of a convolutional block, for example at the end of a group of convolutional layers;

  • After the convolutional output of each UNet encoder stage, for lightweight re-weighting along the channel dimension;

  • In lightweight networks, as a replacement for the heavier SE module.

Why:

  • ECA models channel relationships with very few extra parameters and quickly improves channel feature quality;

  • In brain hemorrhage CT, different channels may be sensitive to different lesion structures, and ECA adjusts channel weights dynamically;

  • It works best when inserted before the feature map has been heavily downsampled (see the kernel-size note below).
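The snippet above fixes k_size at 3. The ECA paper also describes choosing the 1D kernel size adaptively from the channel count; the helper below is a sketch of that heuristic, using gamma = 2 and b = 1, written from memory of the paper rather than copied from its official code:

import math

def eca_kernel_size(channels, gamma=2, b=1):
    # k = |log2(C)/gamma + b/gamma|, rounded to the nearest odd number
    t = int(abs(math.log2(channels) / gamma + b / gamma))
    return t if t % 2 else t + 1

# e.g. eca_kernel_size(64) -> 3, eca_kernel_size(256) -> 5
# layer = ECALayer(channel=256, k_size=eca_kernel_size(256))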

Coordinate Attention (CA)

import torch
import torch.nn as nn
import torch.nn.functional as F

class CoordAtt(nn.Module):
    def __init__(self, inp, oup, reduction=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))  # pool along width, keep height
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))  # pool along height, keep width
        mip = max(8, inp // reduction)
        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = nn.ReLU()
        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        identity = x
        n, c, h, w = x.size()
        # directional pooling along H and W
        x_h = self.pool_h(x)                      # [N, C, H, 1]
        x_w = self.pool_w(x).permute(0, 1, 3, 2)  # [N, C, 1, W] -> [N, C, W, 1]
        y = torch.cat([x_h, x_w], dim=2)          # [N, C, H+W, 1]
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)
        a_h = self.conv_h(x_h).sigmoid()          # attention along the height axis
        a_w = self.conv_w(x_w).sigmoid()          # attention along the width axis
        out = identity * a_h * a_w
        return out

Suitable insertion points:

  • After encoder feature extraction, especially while the spatial resolution is still relatively large;

  • After multi-scale fusion in the decoder, to enhance spatial localization;

  • After skip connections, to reinforce the spatial position information in the features.

Why:

  • CA captures not only channel relationships but also explicit height/width coordinate information, which suits tasks that require precise lesion localization;

  • For CT segmentation of brain hemorrhage, accurately capturing spatial position is critical;

  • It improves the model's perception of edges and fine details (a skip-connection sketch follows below).
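A minimal placement sketch, assuming the CoordAtt class above is in scope; UpWithCoordAtt is a hypothetical decoder step that reweights the encoder skip feature before concatenation:

import torch
import torch.nn as nn

# Hypothetical decoder step: CoordAtt on the skip feature, then concatenate and fuse.
class UpWithCoordAtt(nn.Module):
    def __init__(self, dec_ch, skip_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(dec_ch, dec_ch // 2, kernel_size=2, stride=2)
        self.ca = CoordAtt(skip_ch, skip_ch)     # reweight the skip feature by position
        self.fuse = nn.Conv2d(dec_ch // 2 + skip_ch, skip_ch, kernel_size=3, padding=1)

    def forward(self, dec_feat, skip_feat):
        dec_feat = self.up(dec_feat)             # upsample the decoder feature
        skip_feat = self.ca(skip_feat)           # coordinate attention on the skip path
        return self.fuse(torch.cat([dec_feat, skip_feat], dim=1))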

SKNet Module (Selective Kernel)

import torch
import torch.nn as nn
import torch.nn.functional as F

class SKConv(nn.Module):
    def __init__(self, features, M=2, G=32, r=16, L=32):
        super(SKConv, self).__init__()
        d = max(int(features / r), L)
        self.M = M  # number of branches
        self.features = features
        self.convs = nn.ModuleList()
        for i in range(M):
            self.convs.append(nn.Sequential(
                nn.Conv2d(features, features, kernel_size=3 + i * 2, stride=1, padding=1 + i, groups=G, bias=False),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=True)
            ))
        self.fc = nn.Linear(features, d)
        self.fcs = nn.ModuleList()
        for i in range(M):
            self.fcs.append(nn.Linear(d, features))
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        batch_size = x.size(0)
        feats = []
        for conv in self.convs:
            feats.append(conv(x))
        feats = torch.stack(feats, dim=1)       # [B, M, C, H, W]
        U = torch.sum(feats, dim=1)             # fuse branches: [B, C, H, W]
        s = U.mean(-1).mean(-1)                 # global average pooling: [B, C]
        z = self.fc(s)                          # [B, d]
        attention_vectors = []
        for fc in self.fcs:
            attention_vectors.append(fc(z).unsqueeze(1))          # [B, 1, C]
        attention_vectors = torch.cat(attention_vectors, dim=1)   # [B, M, C]
        attention_vectors = self.softmax(attention_vectors)       # normalize weights across branches
        attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)  # [B, M, C, 1, 1]
        out = (feats * attention_vectors).sum(dim=1)  # weighted sum over branches
        return out

Suitable insertion points:

  • Between convolutional layers, replacing a standard convolution as a multi-scale feature-extraction module;

  • In the middle stages of the encoder, using multi-scale convolutions to dynamically select the receptive field;

  • To strengthen detection of hemorrhage regions at different scales.

Why:

  • SKNet dynamically fuses features from convolution kernels of different sizes, adapting to target regions of different sizes;

  • It responds effectively to hemorrhage blobs of widely varying size in brain CT images;

  • Inserted into the convolutional stages, it better captures multi-scale contextual information (a usage sketch follows below).
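A minimal usage sketch with illustrative channel counts, dropping SKConv (defined above) in where a standard 3x3 convolution stage would sit; note that features must be divisible by the group count G:

import torch
import torch.nn as nn

# Illustrative: a plain stem followed by SKConv; 64 channels is divisible by G = 32.
block = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    SKConv(features=64, M=2, G=32, r=16, L=32),  # two branches with 3x3 and 5x5 kernels
)

x = torch.randn(2, 3, 128, 128)
y = block(x)    # [2, 64, 128, 128]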

