当前位置: 首页 > news >正文

09 - TripletAttention模块

论文《Rotate to Attend: Convolutional Triplet Attention Module》

1、作用

Triplet Attention是一种新颖的注意力机制,它通过捕获跨维度交互,利用三分支结构来计算注意力权重。对于输入张量,Triplet Attention通过旋转操作建立维度间的依赖关系,随后通过残差变换对信道和空间信息进行编码,实现了几乎不增加计算成本的情况下,有效增强视觉表征的能力。

2、机制

1、三分支结构

Triplet Attention包含三个分支,每个分支负责捕获输入的空间维度H或W与信道维度C之间的交互特征。

2、跨维度交互

通过在每个分支中对输入张量进行排列(permute)操作,并通过Z-pool和k×k的卷积层处理,以捕获跨维度的交互特征。

3、注意力权重的生成

利用sigmoid激活层生成注意力权重,并应用于排列后的输入张量,然后将其排列回原始输入形状。

3、 独特优势

1、跨维度交互

Triplet Attention通过捕获输入张量的跨维度交互,提供了丰富的判别特征表征,较之前的注意力机制(如SENet、CBAM等)能够更有效地增强网络的性能。

2、几乎无计算成本增加

相比于传统的注意力机制,Triplet Attention在提升网络性能的同时,几乎不增加额外的计算成本和参数数量,使得它可以轻松地集成到经典的骨干网络中。

3、无需降维

与其他注意力机制不同,Triplet Attention不进行维度降低处理,这避免了因降维可能导致的信息丢失,保证了信道与权重间的直接对应关系。

总的来说,Triplet Attention通过其独特的三分支结构和跨维度交互机制,在提高模型性能的同时,保持了计算效率,显示了其在各种视觉任务中的应用潜力。

4、代码

import torch
import torch.nn as nn# 定义一个基本的卷积模块,包括卷积、批归一化和ReLU激活
class BasicConv(nn.Module):def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):super(BasicConv, self).__init__()self.out_channels = out_planes# 定义卷积层self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)# 条件性地添加批归一化层self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None# 条件性地添加ReLU激活函数self.relu = nn.ReLU() if relu else Nonedef forward(self, x):x = self.conv(x)  # 应用卷积if self.bn is not None:x = self.bn(x)  # 应用批归一化if self.relu is not None:x = self.relu(x)  # 应用ReLUreturn x# 定义ZPool模块,结合最大池化和平均池化结果
class ZPool(nn.Module):def forward(self, x):# 结合最大值和平均值return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)# 定义注意力门,用于根据输入特征生成注意力权重
class AttentionGate(nn.Module):def __init__(self):super(AttentionGate, self).__init__()kernel_size = 7  # 设定卷积核大小self.compress = ZPool()  # 使用ZPool模块self.conv = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)  # 通过卷积调整通道数def forward(self, x):x_compress = self.compress(x)  # 应用ZPoolx_out = self.conv(x_compress)  # 通过卷积生成注意力权重scale = torch.sigmoid_(x_out)  # 应用Sigmoid激活return x * scale  # 将注意力权重乘以原始特征# 定义TripletAttention模块,结合了三种不同方向的注意力门
class TripletAttention(nn.Module):def __init__(self, no_spatial=False):super(TripletAttention, self).__init__()self.cw = AttentionGate()  # 定义宽度方向的注意力门self.hc = AttentionGate()  # 定义高度方向的注意力门self.no_spatial = no_spatial  # 是否忽略空间注意力if not no_spatial:self.hw = AttentionGate()  # 定义空间方向的注意力门def forward(self, x):# 应用注意力门并结合结果x_perm1 = x.permute(0, 2, 1, 3).contiguous()  # 转置以应用宽度方向的注意力x_out1 = self.cw(x_perm1)x_out11 = x_out1.permute(0, 2, 1, 3).contiguous()  # 还原转置x_perm2 = x.permute(0, 3, 2, 1).contiguous()  # 转置以应用高度方向的注意力x_out2 = self.hc(x_perm2)x_out21 = x_out2.permute(0, 3, 2, 1).contiguous()  # 还原转置if not self.no_spatial:x_out = self.hw(x)  # 应用空间注意力x_out = 1 / 3 * (x_out + x_out11 + x_out21)  # 结合三个方向的结果else:x_out = 1 / 2 * (x_out11 + x_out21)  # 结合两个方向的结果(如果no_spatial为True)return x_out# 示例代码
if __name__ == '__main__':input = torch.randn(50, 512, 7, 7)  # 生成随机输入triplet = TripletAttention()  # 实例化TripletAttentionoutput = triplet(input)  # 应用TripletAttentionprint(output.shape)  # 打印输出形状
http://www.xdnf.cn/news/1020565.html

相关文章:

  • 百空间成网 可信数据生态如何重塑数字时代生产关系
  • 基于Docker实现frp之snowdreamtech/frps
  • Linux NFS服务器配置
  • 手阳明大肠经之下廉穴
  • goland 的 dug 设置
  • 我会秘书长杨添天带队赴光明食品集团外高桥食品产业园区考察调研
  • 为何京东与蚂蚁集团竞相申请稳定币牌照?
  • 阿里云服务器操作系统 V3(内核版本 5.10)
  • 数据结构与算法-线性表-线性表的应用
  • electron在单例中实现双击打开文件,并重复打开其他文件
  • leetcode HOT 100(128.连续最长序列)
  • day54 python对抗生成网络
  • C# 结构(属性和字段初始化语句和结构是密封的)
  • C#最佳实践:推荐使用 null 条件运算符调用事件
  • 软考 系统架构设计师系列知识点之杂项集萃(88)
  • 偷懒一下下
  • 在C#中的乐观锁和悲观锁
  • 双碳时代多场景能耗管理实战:数据中心、工业园、商业体如何精准降本?
  • 论坛系统自动化测试
  • C# .NET Core 源代码生成器(dotnet source generators)
  • ROS2编译的理解,与GPT对话
  • 浏览器播放监控画面
  • 【谷歌登录SDK集成】
  • torch 高维矩阵乘法分析,一文说透
  • 信号(瞬时)频率求解与仿真实践(2)
  • 数据库中的Schema是什么?不同数据库中Schema的含义
  • 使用HashMap或者List模拟数据库插入和查询数据
  • 橡胶厂生产线的“协议翻译官”:DeviceNet转Modbus RTU网关实战记
  • PCB 层压板的 Dk 和 Df 表征方法 – 第一部分
  • Linux(Centos 7.6)命令详解:w