当前位置: 首页 > web >正文

08 - CoTAttention模块

论文《Contextual Transformer Networks for Visual Recognition》

1、 作用

Contextual Transformer (CoT) block 设计为视觉识别的一种新颖的 Transformer 风格模块。该设计充分利用输入键之间的上下文信息指导动态注意力矩阵的学习,从而加强视觉表示的能力。CoT block 首先通过 3x3 卷积对输入键进行上下文编码,得到输入的静态上下文表示。然后,将编码后的键与输入查询合并,通过两个连续的 1x1 卷积学习动态多头注意力矩阵。学习到的注意力矩阵乘以输入值,实现输入的动态上下文表示。最终将静态和动态上下文表示的融合作为输出。

2、机制

1、上下文编码

通过 3x3 卷积在所有邻居键内部空间上下文化每个键表示,捕获键之间的静态上下文信息。

2、动态注意力学习

基于查询和上下文化的键的连接,通过两个连续的 1x1 卷积产生注意力矩阵,这一过程自然地利用每个查询和所有键之间的相互关系进行自我注意力学习,并由静态上下文指导。

3、静态和动态上下文的融合

将静态上下文和通过上下文化自注意力得到的动态上下文结合,作为 CoT block 的最终输出。

3、 独特优势

1、上下文感知

CoT 通过在自注意力学习中探索输入键之间的富上下文信息,使模型能够更准确地捕获视觉内容的细微差异。

2、动静态上下文的统一

CoT 设计巧妙地将上下文挖掘与自注意力学习统一到单一架构中,既利用键之间的静态关系又探索动态特征交互,提升了模型的表达能力。

3、灵活替换与优化

CoT block 可以直接替换现有 ResNet 架构中的标准卷积,不增加参数和 FLOP 预算的情况下实现转换为 Transformer 风格的骨干网络(CoTNet),通过广泛的实验验证了其在多种应用(如图像识别、目标检测和实例分割)中的优越性。

4、代码

# 导入必要的PyTorch模块
import torch
from torch import nn
from torch.nn import functional as Fclass CoTAttention(nn.Module):# 初始化CoT注意力模块def __init__(self, dim=512, kernel_size=3):super().__init__()self.dim = dim  # 输入的通道数self.kernel_size = kernel_size  # 卷积核大小# 定义用于键(key)的卷积层,包括一个分组卷积,BatchNorm和ReLU激活self.key_embed = nn.Sequential(nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size//2, groups=4, bias=False),nn.BatchNorm2d(dim),nn.ReLU())# 定义用于值(value)的卷积层,包括一个1x1卷积和BatchNormself.value_embed = nn.Sequential(nn.Conv2d(dim, dim, 1, bias=False),nn.BatchNorm2d(dim))# 缩小因子,用于降低注意力嵌入的维度factor = 4# 定义注意力嵌入层,由两个卷积层、一个BatchNorm层和ReLU激活组成self.attention_embed = nn.Sequential(nn.Conv2d(2*dim, 2*dim//factor, 1, bias=False),nn.BatchNorm2d(2*dim//factor),nn.ReLU(),nn.Conv2d(2*dim//factor, kernel_size*kernel_size*dim, 1))def forward(self, x):# 前向传播函数bs, c, h, w = x.shape  # 输入特征的尺寸k1 = self.key_embed(x)  # 生成键的静态表示v = self.value_embed(x).view(bs, c, -1)  # 生成值的表示并调整形状y = torch.cat([k1, x], dim=1)  # 将键的静态表示和原始输入连接att = self.attention_embed(y)  # 生成动态注意力权重att = att.reshape(bs, c, self.kernel_size*self.kernel_size, h, w)att = att.mean(2, keepdim=False).view(bs, c, -1)  # 计算注意力权重的均值并调整形状k2 = F.softmax(att, dim=-1) * v  # 应用注意力权重到值上k2 = k2.view(bs, c, h, w)  # 调整形状以匹配输出return k1 + k2  # 返回键的静态和动态表示的总和# 实例化CoTAttention模块并测试
if __name__ == '__main__':block = CoTAttention(64)  # 创建一个输入通道数为64的CoTAttention实例input = torch.rand(1, 64, 64, 64)  # 创建一个随机输入output = block(input)  # 通过CoTAttention模块处理输入print(output.shape)  # 打印输入和输出的尺寸
http://www.xdnf.cn/news/14113.html

相关文章:

  • 使用Claude Desktop快速体验MCP servers!
  • 短剧热浪,席卷海内外。
  • Rust编写Shop管理系统
  • 长春光博会 | 麒麟信安:构建工业数字化安全基座,赋能智能制造转型升级
  • 深入剖析Redis高性能的原因,IO多路复用模型,Redis数据迁移,分布式锁实现
  • Python数据可视化:Seaborn入门与实践
  • LeetCode 744.寻找比目标字母大的最小字母
  • 【动手学深度学习】3.5. 图像分类数据集
  • 3D模型格式转换HOOPS Exchange与工程设计软件自带转换器对比分析
  • 力扣-322.零钱兑换
  • 最新四六级写作好词好句锦囊(持续更新中)
  • 【VS2022 配置 ACADOS环境】
  • Java集合 - ArrayList底层源码解析
  • 精益数据分析(102/126):SaaS用户流失率优化与OfficeDrop的转型启示
  • Trae国内版Builder模式VS Chat模式
  • 1.3、SDH光接口类型
  • powerShell调用cmd
  • Epigenetics ATAC-seq助力解析炎症性细胞因子IL-1刺激引起的动态染色质可及性变化
  • Marketing Agent实施成本全解析:价格构成、影响因素与技术选型建议
  • vector的用法
  • Web网页端即时通讯源码/IM聊天源码RainbowChat-Web
  • 一阶拟线性偏微分方程光滑解的存在性与最大初始振幅分析
  • Rocky Linux 9 系统安装配置图解教程并做简单配置
  • Node.js下载安装及环境配置教程
  • IEEE-745标准4字节16进制转浮点
  • 【VUE3】基于Vue3和Element Plus的递归组件实现多级导航栏
  • 社会应用融智学的人力资源模式:潜能开发评估;认知基建资产
  • 【为什么InnoDB用B+树?从存储结构到索引设计深度解析】
  • 车载以太网-switch
  • 无人机噪音处理模块技术分析