(Plug-and-Play Modules - Attention) 64. (2024) LSKA: Large Separable Kernel Attention
Table of Contents
- 1. Large Separable Kernel Attention
- 2. Code Implementation
Paper: Large Separable Kernel Attention: Rethinking the Large Kernel Attention Design in CNN
Code: https://github.com/StevenLauHKHK/Large-Separable-Kernel-Attention
1. Large Separable Kernel Attention
The LKA module in VAN, by building its large receptive field directly on a large depth-wise convolution kernel, makes the computational cost grow quadratically with kernel size, leaving the model inefficient. Combining a small-kernel depth-wise convolution with a large-kernel dilated depth-wise convolution alleviates this growth in computation, but the number of parameters still increases with kernel size, which limits the model at extremely large kernels. This paper proposes Large Separable Kernel Attention (LSKA), which aims to solve the quadratic growth of computation and memory that existing large-kernel CNN designs suffer as the kernel size increases.
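For reference, below is a minimal sketch of the original LKA block from VAN, using its commonly cited configuration (a 5x5 depth-wise convolution, a 7x7 depth-wise convolution with dilation 3, and a 1x1 convolution, together approximating a roughly 21x21 kernel). The class name LKA follows the VAN code; treat this as an illustrative reference rather than part of this paper's implementation.

import torch
import torch.nn as nn

class LKA(nn.Module):
    # Reference sketch of VAN's Large Kernel Attention (approximately a 21x21 receptive field).
    def __init__(self, dim):
        super().__init__()
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)                     # 5x5 depth-wise conv
        self.conv_spatial = nn.Conv2d(dim, dim, 7, padding=9, groups=dim, dilation=3)  # 7x7 dilated depth-wise conv
        self.conv1 = nn.Conv2d(dim, dim, 1)                                            # 1x1 channel-mixing conv

    def forward(self, x):
        u = x.clone()
        attn = self.conv0(x)
        attn = self.conv_spatial(attn)
        attn = self.conv1(attn)
        return u * attn                                                                # gate the input with the attention map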
The core idea of the LSKA module is to decompose each 2D convolution kernel into a combination of 1D kernels, thereby reducing computation and memory usage. Specifically, LSKA splits the large depth-wise convolution kernels into horizontal and vertical 1D kernels that are applied in sequence, covering the same receptive field as LKA and producing an output of the same shape with comparable effectiveness.
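As a quick check on the savings from the 1D decomposition, the snippet below (an illustrative example, not taken from the paper; dw_weights is a hypothetical helper) counts the depth-wise weights of a k x k 2D kernel versus its (1 x k) + (k x 1) decomposition: per channel the former needs k^2 weights while the latter needs only 2k.

import torch.nn as nn

def dw_weights(*convs):
    # Count depth-wise kernel weights only (biases excluded for clarity).
    return sum(c.weight.numel() for c in convs)

dim, k = 64, 35
conv_2d   = nn.Conv2d(dim, dim, kernel_size=k, padding=k // 2, groups=dim, bias=False)
conv_1d_h = nn.Conv2d(dim, dim, kernel_size=(1, k), padding=(0, k // 2), groups=dim, bias=False)
conv_1d_v = nn.Conv2d(dim, dim, kernel_size=(k, 1), padding=(k // 2, 0), groups=dim, bias=False)

print(dw_weights(conv_2d))               # 64 * 35 * 35 = 78400
print(dw_weights(conv_1d_h, conv_1d_v))  # 64 * (35 + 35) = 4480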
The LSKA module can be viewed as a variant of LKA that replaces LKA's 2D kernels with cascades of 1D kernels. The steps are as follows (a minimal single-kernel-size sketch is given after this list):
- Horizontal 1D convolution: apply a horizontal (1xk) depth-wise convolution to the input feature map.
- Vertical 1D convolution: apply a vertical (kx1) depth-wise convolution to the previous output; in the implementation this horizontal/vertical pair appears twice, first as local convolutions and then as dilated convolutions that enlarge the receptive field to the target kernel size.
- 1x1 convolution: apply a 1x1 convolution to the previous output to obtain the attention map.
- Feature fusion: multiply the attention map element-wise with the input feature map to produce the final output.
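A minimal sketch of this step sequence for a single kernel size (k_size = 7, chosen only for brevity; the class name LSKASketch is illustrative) looks as follows; the full implementation covering all supported kernel sizes is given in Section 2.

import torch
import torch.nn as nn

class LSKASketch(nn.Module):
    # Illustrative single-size LSKA: local 1D pair -> dilated 1D pair -> 1x1 conv -> element-wise gate.
    def __init__(self, dim):
        super().__init__()
        self.conv0h = nn.Conv2d(dim, dim, (1, 3), padding=(0, 1), groups=dim)                      # horizontal local 1D conv
        self.conv0v = nn.Conv2d(dim, dim, (3, 1), padding=(1, 0), groups=dim)                      # vertical local 1D conv
        self.conv_spatial_h = nn.Conv2d(dim, dim, (1, 3), padding=(0, 2), groups=dim, dilation=2)  # horizontal dilated 1D conv
        self.conv_spatial_v = nn.Conv2d(dim, dim, (3, 1), padding=(2, 0), groups=dim, dilation=2)  # vertical dilated 1D conv
        self.conv1 = nn.Conv2d(dim, dim, 1)                                                        # 1x1 conv -> attention map

    def forward(self, x):
        attn = self.conv0h(x)
        attn = self.conv0v(attn)
        attn = self.conv_spatial_h(attn)
        attn = self.conv_spatial_v(attn)
        attn = self.conv1(attn)
        return x * attn

x = torch.randn(2, 32, 64, 64)
print(LSKASketch(32)(x).shape)  # torch.Size([2, 32, 64, 64])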
Advantages of LSKA:
- Computational efficiency: decomposing the large depth-wise kernels into cascades of 1D kernels markedly reduces both parameters and computation, so the module stays efficient even for very large kernels (a parameter-count comparison follows this list).
- Performance: LSKA achieves accuracy comparable to LKA while being more computationally efficient.
- Scalability: LSKA scales to much larger kernels without sacrificing performance, which gives it an advantage in modeling long-range dependencies.
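To make the efficiency claim concrete, the snippet below compares depth-wise parameter counts at the scale of k_size = 35, assuming the corresponding LKA would use 5x5 and 11x11 2D depth-wise kernels (the 2D counterparts of LSKA's 1D kernels in the k_size == 35 branch of the code in Section 2); this pairing is an assumption made here for illustration only.

import torch.nn as nn

def n_params(module):
    return sum(p.numel() for p in module.parameters())

dim = 64

# Assumed 2D depth-wise kernels of an LKA-style block at this scale: 5x5 local + 11x11 dilated (dilation 3).
lka_dw = nn.Sequential(
    nn.Conv2d(dim, dim, 5, padding=2, groups=dim),
    nn.Conv2d(dim, dim, 11, padding=15, dilation=3, groups=dim),
)

# LSKA's 1D decomposition of the same kernels (matching the k_size == 35 branch in Section 2).
lska_dw = nn.Sequential(
    nn.Conv2d(dim, dim, (1, 5), padding=(0, 2), groups=dim),
    nn.Conv2d(dim, dim, (5, 1), padding=(2, 0), groups=dim),
    nn.Conv2d(dim, dim, (1, 11), padding=(0, 15), dilation=3, groups=dim),
    nn.Conv2d(dim, dim, (11, 1), padding=(15, 0), dilation=3, groups=dim),
)

print(n_params(lka_dw), n_params(lska_dw))  # 9472 vs 2304 (biases included)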
Large Separable Kernel Attention structure diagram:
2. Code Implementation
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class LSKA(nn.Module):
    """Large Separable Kernel Attention (LSKA)."""

    def __init__(self, dim, k_size):
        super().__init__()
        self.k_size = k_size

        # Each branch decomposes the equivalent 2D kernels into a horizontal (1 x k) and a
        # vertical (k x 1) depth-wise convolution. conv0h/conv0v capture local context;
        # conv_spatial_h/conv_spatial_v are dilated to reach the target kernel size k_size.
        if k_size == 7:
            self.conv0h = nn.Conv2d(dim, dim, kernel_size=(1, 3), stride=(1, 1), padding=(0, (3 - 1) // 2), groups=dim)
            self.conv0v = nn.Conv2d(dim, dim, kernel_size=(3, 1), stride=(1, 1), padding=((3 - 1) // 2, 0), groups=dim)
            self.conv_spatial_h = nn.Conv2d(dim, dim, kernel_size=(1, 3), stride=(1, 1), padding=(0, 2), groups=dim, dilation=2)
            self.conv_spatial_v = nn.Conv2d(dim, dim, kernel_size=(3, 1), stride=(1, 1), padding=(2, 0), groups=dim, dilation=2)
        elif k_size == 11:
            self.conv0h = nn.Conv2d(dim, dim, kernel_size=(1, 3), stride=(1, 1), padding=(0, (3 - 1) // 2), groups=dim)
            self.conv0v = nn.Conv2d(dim, dim, kernel_size=(3, 1), stride=(1, 1), padding=((3 - 1) // 2, 0), groups=dim)
            self.conv_spatial_h = nn.Conv2d(dim, dim, kernel_size=(1, 5), stride=(1, 1), padding=(0, 4), groups=dim, dilation=2)
            self.conv_spatial_v = nn.Conv2d(dim, dim, kernel_size=(5, 1), stride=(1, 1), padding=(4, 0), groups=dim, dilation=2)
        elif k_size == 23:
            self.conv0h = nn.Conv2d(dim, dim, kernel_size=(1, 5), stride=(1, 1), padding=(0, (5 - 1) // 2), groups=dim)
            self.conv0v = nn.Conv2d(dim, dim, kernel_size=(5, 1), stride=(1, 1), padding=((5 - 1) // 2, 0), groups=dim)
            self.conv_spatial_h = nn.Conv2d(dim, dim, kernel_size=(1, 7), stride=(1, 1), padding=(0, 9), groups=dim, dilation=3)
            self.conv_spatial_v = nn.Conv2d(dim, dim, kernel_size=(7, 1), stride=(1, 1), padding=(9, 0), groups=dim, dilation=3)
        elif k_size == 35:
            self.conv0h = nn.Conv2d(dim, dim, kernel_size=(1, 5), stride=(1, 1), padding=(0, (5 - 1) // 2), groups=dim)
            self.conv0v = nn.Conv2d(dim, dim, kernel_size=(5, 1), stride=(1, 1), padding=((5 - 1) // 2, 0), groups=dim)
            self.conv_spatial_h = nn.Conv2d(dim, dim, kernel_size=(1, 11), stride=(1, 1), padding=(0, 15), groups=dim, dilation=3)
            self.conv_spatial_v = nn.Conv2d(dim, dim, kernel_size=(11, 1), stride=(1, 1), padding=(15, 0), groups=dim, dilation=3)
        elif k_size == 41:
            self.conv0h = nn.Conv2d(dim, dim, kernel_size=(1, 5), stride=(1, 1), padding=(0, (5 - 1) // 2), groups=dim)
            self.conv0v = nn.Conv2d(dim, dim, kernel_size=(5, 1), stride=(1, 1), padding=((5 - 1) // 2, 0), groups=dim)
            self.conv_spatial_h = nn.Conv2d(dim, dim, kernel_size=(1, 13), stride=(1, 1), padding=(0, 18), groups=dim, dilation=3)
            self.conv_spatial_v = nn.Conv2d(dim, dim, kernel_size=(13, 1), stride=(1, 1), padding=(18, 0), groups=dim, dilation=3)
        elif k_size == 53:
            self.conv0h = nn.Conv2d(dim, dim, kernel_size=(1, 5), stride=(1, 1), padding=(0, (5 - 1) // 2), groups=dim)
            self.conv0v = nn.Conv2d(dim, dim, kernel_size=(5, 1), stride=(1, 1), padding=((5 - 1) // 2, 0), groups=dim)
            self.conv_spatial_h = nn.Conv2d(dim, dim, kernel_size=(1, 17), stride=(1, 1), padding=(0, 24), groups=dim, dilation=3)
            self.conv_spatial_v = nn.Conv2d(dim, dim, kernel_size=(17, 1), stride=(1, 1), padding=(24, 0), groups=dim, dilation=3)

        # 1x1 convolution that mixes channels and produces the attention map.
        self.conv1 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        u = x.clone()                     # keep the input for the gating step
        attn = self.conv0h(x)             # horizontal 1D depth-wise conv
        attn = self.conv0v(attn)          # vertical 1D depth-wise conv
        attn = self.conv_spatial_h(attn)  # horizontal 1D dilated depth-wise conv
        attn = self.conv_spatial_v(attn)  # vertical 1D dilated depth-wise conv
        attn = self.conv1(attn)           # 1x1 conv -> attention map
        return u * attn                   # element-wise gating of the input


if __name__ == '__main__':
    x = torch.randn(4, 64, 128, 128).cuda()
    model = LSKA(64, 7).cuda()
    out = model(x)
    print(out.shape)
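Note that the constructor above only defines the 1D convolution branches for k_size in {7, 11, 23, 35, 41, 53}; any other value leaves those layers undefined and forward will fail. In VAN-style backbones, LKA/LSKA is usually wrapped in a small attention block with 1x1 projections, a GELU activation, and a residual connection. The sketch below shows one such wrapper, assuming the LSKA class defined above; the wrapper itself (class name SpatialAttentionBlock, default k_size=35) is illustrative and not prescribed by the paper.

import torch
import torch.nn as nn

class SpatialAttentionBlock(nn.Module):
    # Illustrative VAN-style wrapper: 1x1 proj -> GELU -> LSKA -> 1x1 proj, with a residual connection.
    def __init__(self, dim, k_size=35):
        super().__init__()
        self.proj_1 = nn.Conv2d(dim, dim, 1)
        self.activation = nn.GELU()
        self.spatial_gating_unit = LSKA(dim, k_size)
        self.proj_2 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        shortcut = x
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        return x + shortcut


if __name__ == '__main__':
    block = SpatialAttentionBlock(64, k_size=35)
    y = block(torch.randn(2, 64, 56, 56))
    print(y.shape)  # torch.Size([2, 64, 56, 56])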