当前位置: 首页 > news >正文

(即插即用模块-Attention部分) 六十四、(2024) LSKA 可分离大核注意力

在这里插入图片描述

文章目录

  • 1、Large Separable Kernel Attention
  • 2、代码实现

paper:Large Separable Kernel Attention: Rethinking the Large Kernel Attention Design in CNN

Code:https://github.com/StevenLauHKHK/Large-Separable-Kernel-Attention


1、Large Separable Kernel Attention

VAN 中的 LKA 通过直接使用深度可分离卷积层的大卷积核,导致计算量随卷积核尺寸的增大而呈二次增长,导致模型效率低下。此外,虽然使用深度可分离卷积层的小卷积核和扩张卷积层的大卷积核组合,缓解了计算量增长的问题,但参数量仍然随卷积核尺寸的增大而增长,限制了模型在极端大卷积核下的使用。而这篇论文提出 可分离大核注意力(Large Separable Kernel Attention),LSKA 的提出旨在解决现有大型卷积神经网络(CNN)中,随着卷积核尺寸增大,计算量和内存占用呈二次增长的问题。

LSKA 模块的核心思想是将 2D 卷积核分解为多个 1D 卷积核的组合,从而降低计算量和内存占用。具体来说,LSKA 将深度可分离卷积层的大卷积核分解为水平方向和垂直方向的 1D 卷积核,并依次进行卷积操作,最终得到与 LKA 相同的输出结果。

LSKA 模块可以看作是 LKA 模块的一个变种,它通过将 LKA 中的 2D 卷积核替换为 1D 卷积核的组合来实现。具体步骤如下:

  1. 水平方向 1D 卷积:使用水平方向的 1D 卷积核对输入特征图进行卷积操作。
  2. 垂直方向 1D 卷积:使用垂直方向的 1D 卷积核对上一步的输出进行卷积操作。
  3. 1x1 卷积:使用 1x1 卷积核对上一步的输出进行卷积操作,得到注意力图。
  4. 特征图融合:将注意力图与输入特征图进行逐元素相乘,得到最终的特征图输出。

LSKA 的优势:

  • 计算效率:通过将深度可分离卷积层的大卷积核分解为多个 1D 卷积核的组合,有效降低了参数量和计算量,即使对于非常大的卷积核也能保持高效的计算。
  • 性能:LSKA 在保持与 LKA 相当性能的同时,实现了更高的计算效率。
  • 可扩展性:LSKA 模块能够有效地扩展到更大的卷积核,而不会牺牲性能,这使得它在处理长距离依赖关系方面具有优势。

Large Separable Kernel Attention 结构图:
在这里插入图片描述


2、代码实现

import math
import torch
import torch.nn as nn
import torch.nn.functional as Fclass LSKA(nn.Module):def __init__(self, dim, k_size):super().__init__()self.k_size = k_sizeif k_size == 7:self.conv0h = nn.Conv2d(dim, dim, kernel_size=(1, 3), stride=(1,1), padding=(0,(3-1)//2), groups=dim)self.conv0v = nn.Conv2d(dim, dim, kernel_size=(3, 1), stride=(1,1), padding=((3-1)//2,0), groups=dim)self.conv_spatial_h = nn.Conv2d(dim, dim, kernel_size=(1, 3), stride=(1,1), padding=(0,2), groups=dim, dilation=2)self.conv_spatial_v = nn.Conv2d(dim, dim, kernel_size=(3, 1), stride=(1,1), padding=(2,0), groups=dim, dilation=2)elif k_size == 11:self.conv0h = nn.Conv2d(dim, dim, kernel_size=(1, 3), stride=(1,1), padding=(0,(3-1)//2), groups=dim)self.conv0v = nn.Conv2d(dim, dim, kernel_size=(3, 1), stride=(1,1), padding=((3-1)//2,0), groups=dim)self.conv_spatial_h = nn.Conv2d(dim, dim, kernel_size=(1, 5), stride=(1,1), padding=(0,4), groups=dim, dilation=2)self.conv_spatial_v = nn.Conv2d(dim, dim, kernel_size=(5, 1), stride=(1,1), padding=(4,0), groups=dim, dilation=2)elif k_size == 23:self.conv0h = nn.Conv2d(dim, dim, kernel_size=(1, 5), stride=(1,1), padding=(0,(5-1)//2), groups=dim)self.conv0v = nn.Conv2d(dim, dim, kernel_size=(5, 1), stride=(1,1), padding=((5-1)//2,0), groups=dim)self.conv_spatial_h = nn.Conv2d(dim, dim, kernel_size=(1, 7), stride=(1,1), padding=(0,9), groups=dim, dilation=3)self.conv_spatial_v = nn.Conv2d(dim, dim, kernel_size=(7, 1), stride=(1,1), padding=(9,0), groups=dim, dilation=3)elif k_size == 35:self.conv0h = nn.Conv2d(dim, dim, kernel_size=(1, 5), stride=(1,1), padding=(0,(5-1)//2), groups=dim)self.conv0v = nn.Conv2d(dim, dim, kernel_size=(5, 1), stride=(1,1), padding=((5-1)//2,0), groups=dim)self.conv_spatial_h = nn.Conv2d(dim, dim, kernel_size=(1, 11), stride=(1,1), padding=(0,15), groups=dim, dilation=3)self.conv_spatial_v = nn.Conv2d(dim, dim, kernel_size=(11, 1), stride=(1,1), padding=(15,0), groups=dim, dilation=3)elif k_size == 41:self.conv0h = nn.Conv2d(dim, dim, kernel_size=(1, 5), stride=(1,1), padding=(0,(5-1)//2), groups=dim)self.conv0v = nn.Conv2d(dim, dim, kernel_size=(5, 1), stride=(1,1), padding=((5-1)//2,0), groups=dim)self.conv_spatial_h = nn.Conv2d(dim, dim, kernel_size=(1, 13), stride=(1,1), padding=(0,18), groups=dim, dilation=3)self.conv_spatial_v = nn.Conv2d(dim, dim, kernel_size=(13, 1), stride=(1,1), padding=(18,0), groups=dim, dilation=3)elif k_size == 53:self.conv0h = nn.Conv2d(dim, dim, kernel_size=(1, 5), stride=(1,1), padding=(0,(5-1)//2), groups=dim)self.conv0v = nn.Conv2d(dim, dim, kernel_size=(5, 1), stride=(1,1), padding=((5-1)//2,0), groups=dim)self.conv_spatial_h = nn.Conv2d(dim, dim, kernel_size=(1, 17), stride=(1,1), padding=(0,24), groups=dim, dilation=3)self.conv_spatial_v = nn.Conv2d(dim, dim, kernel_size=(17, 1), stride=(1,1), padding=(24,0), groups=dim, dilation=3)self.conv1 = nn.Conv2d(dim, dim, 1)def forward(self, x):u = x.clone()attn = self.conv0h(x)attn = self.conv0v(attn)attn = self.conv_spatial_h(attn)attn = self.conv_spatial_v(attn)attn = self.conv1(attn)return u * attnif __name__ == '__main__':x = torch.randn(4, 64, 128, 128).cuda()model = LSKA(64, 7).cuda()out = model(x)print(out.shape)
http://www.xdnf.cn/news/275779.html

相关文章:

  • ubuntu-PyQt5安装+PyCharm配置QtDesigner + QtUIC
  • 关于离散化算法的看法与感悟
  • 软考-软件设计师中级备考 8、进程管理
  • 49认知干货:产品的生命周期及类型汇总
  • 【Java项目脚手架系列】第一篇:Maven基础项目脚手架
  • Rust的安全卫生原则
  • 【PostgreSQL数据分析实战:从数据清洗到可视化全流程】2.2 多表关联技术(INNER JOIN/LEFT JOIN/FULL JOIN)
  • C++八股--6--mysql 日志与并发控制
  • WSL在D盘安装Ubuntu
  • 纯文本Text转Html网页转换器
  • 方案精读:110页华为云数据中心解决方案技术方案【附全文阅读】
  • 项目收尾管理
  • 时序分解 | Matlab基于WOA-MVMD鲸鱼算法优化多元变分模态分解
  • C盘莫名其妙一直变大
  • 智能工厂边缘计算:从数据采集到实时决策
  • WPF之尺寸属性层次
  • 如何从GitHub上调研优秀的开源项目,并魔改应用于工作中?
  • 【言语理解】中心理解题目之选项分析
  • Unity与Unreal Engine(UE)的深度解析及高级用法
  • 【AI面试准备】模型自动化评估经验
  • MCP协议与Dify集成教程
  • 华中科技大学系统结构慕课部分答案
  • 33.降速提高EMC能力
  • 深度学习中的数据增强:提升食物图像分类模型性能的关键策略
  • 【自存】python使用matplotlib正常显示中文、负号
  • Android ART运行时无缝替换Dalvik虚拟机的过程分析
  • Android运行时ART加载OAT文件的过程
  • 跨学科项目式学习的AI脚手架设计:理论框架与实践路径研究
  • 从头训练小模型: 4 lora 微调
  • 【51单片机6位数码管显示时间与秒表】2022-5-8