08 - CoTAttention模块
论文《Contextual Transformer Networks for Visual Recognition》
1、 作用
Contextual Transformer (CoT) block 设计为视觉识别的一种新颖的 Transformer 风格模块。该设计充分利用输入键之间的上下文信息指导动态注意力矩阵的学习,从而加强视觉表示的能力。CoT block 首先通过 3x3 卷积对输入键进行上下文编码,得到输入的静态上下文表示。然后,将编码后的键与输入查询合并,通过两个连续的 1x1 卷积学习动态多头注意力矩阵。学习到的注意力矩阵乘以输入值,实现输入的动态上下文表示。最终将静态和动态上下文表示的融合作为输出。
2、机制
1、上下文编码:
通过 3x3 卷积在所有邻居键内部空间上下文化每个键表示,捕获键之间的静态上下文信息。
2、动态注意力学习:
基于查询和上下文化的键的连接,通过两个连续的 1x1 卷积产生注意力矩阵,这一过程自然地利用每个查询和所有键之间的相互关系进行自我注意力学习,并由静态上下文指导。
3、静态和动态上下文的融合:
将静态上下文和通过上下文化自注意力得到的动态上下文结合,作为 CoT block 的最终输出。
3、 独特优势
1、上下文感知:
CoT 通过在自注意力学习中探索输入键之间的富上下文信息,使模型能够更准确地捕获视觉内容的细微差异。
2、动静态上下文的统一:
CoT 设计巧妙地将上下文挖掘与自注意力学习统一到单一架构中,既利用键之间的静态关系又探索动态特征交互,提升了模型的表达能力。
3、灵活替换与优化:
CoT block 可以直接替换现有 ResNet 架构中的标准卷积,不增加参数和 FLOP 预算的情况下实现转换为 Transformer 风格的骨干网络(CoTNet),通过广泛的实验验证了其在多种应用(如图像识别、目标检测和实例分割)中的优越性。
4、代码
# 导入必要的PyTorch模块
import torch
from torch import nn
from torch.nn import functional as Fclass CoTAttention(nn.Module):# 初始化CoT注意力模块def __init__(self, dim=512, kernel_size=3):super().__init__()self.dim = dim # 输入的通道数self.kernel_size = kernel_size # 卷积核大小# 定义用于键(key)的卷积层,包括一个分组卷积,BatchNorm和ReLU激活self.key_embed = nn.Sequential(nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size//2, groups=4, bias=False),nn.BatchNorm2d(dim),nn.ReLU())# 定义用于值(value)的卷积层,包括一个1x1卷积和BatchNormself.value_embed = nn.Sequential(nn.Conv2d(dim, dim, 1, bias=False),nn.BatchNorm2d(dim))# 缩小因子,用于降低注意力嵌入的维度factor = 4# 定义注意力嵌入层,由两个卷积层、一个BatchNorm层和ReLU激活组成self.attention_embed = nn.Sequential(nn.Conv2d(2*dim, 2*dim//factor, 1, bias=False),nn.BatchNorm2d(2*dim//factor),nn.ReLU(),nn.Conv2d(2*dim//factor, kernel_size*kernel_size*dim, 1))def forward(self, x):# 前向传播函数bs, c, h, w = x.shape # 输入特征的尺寸k1 = self.key_embed(x) # 生成键的静态表示v = self.value_embed(x).view(bs, c, -1) # 生成值的表示并调整形状y = torch.cat([k1, x], dim=1) # 将键的静态表示和原始输入连接att = self.attention_embed(y) # 生成动态注意力权重att = att.reshape(bs, c, self.kernel_size*self.kernel_size, h, w)att = att.mean(2, keepdim=False).view(bs, c, -1) # 计算注意力权重的均值并调整形状k2 = F.softmax(att, dim=-1) * v # 应用注意力权重到值上k2 = k2.view(bs, c, h, w) # 调整形状以匹配输出return k1 + k2 # 返回键的静态和动态表示的总和# 实例化CoTAttention模块并测试
if __name__ == '__main__':block = CoTAttention(64) # 创建一个输入通道数为64的CoTAttention实例input = torch.rand(1, 64, 64, 64) # 创建一个随机输入output = block(input) # 通过CoTAttention模块处理输入print(output.shape) # 打印输入和输出的尺寸