(Plug-and-Play Modules - Attention) 63. (2024 CVPR) MLKA: Multi-scale Large Kernel Attention
Contents
- 1. Multi-scale Large Kernel Attention
- 2. Code Implementation
Paper: MULTI-SCALE ATTENTION NETWORK FOR SINGLE IMAGE SUPER-RESOLUTION
Code:https://github.com/icandle/MAN
1. Multi-scale Large Kernel Attention
The goal is to effectively model long-range correlations between different regions of an image while avoiding the "blocking artifacts" introduced by large convolution kernels. Building on LKA, this paper proposes Multi-scale Large Kernel Attention (MLKA) for single image super-resolution; MLKA combines large-kernel decomposition with a multi-scale mechanism to achieve both goals.
How MLKA works:
- Input feature map X: the input feature map X is split into several groups, each containing the same number of channels.
- LKA modules: an LKA module is applied to each group, producing attention maps LKAi at different scales.
- Gating module: to avoid the blocking artifacts caused by dilated convolutions, the attention map of each group is dynamically recalibrated, which better preserves local texture. Applying a gating module to each LKAi yields a gated attention map MLKAi (see the single-branch sketch after this list).
- Aggregation: all MLKAi are aggregated (concatenated along the channel dimension) to form the final attention map.
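As a minimal sketch of one gated branch, assuming the depth-wise conv + depth-wise dilated conv + 1x1 point-wise conv decomposition used in the full implementation below (the class and argument names here are illustrative, not from the official repo):

```python
import torch
import torch.nn as nn

class GatedLKABranch(nn.Module):
    """Sketch of one gated MLKA branch: MLKA_i = LKA_i(a_i) * X_i(a_i)."""

    def __init__(self, dim, k=5, d=3):
        super().__init__()
        # LKA_i: depth-wise conv -> depth-wise dilated conv -> 1x1 point-wise conv,
        # approximating one large kernel at a fraction of the cost.
        self.lka = nn.Sequential(
            nn.Conv2d(dim, dim, k, padding=k // 2, groups=dim),
            nn.Conv2d(dim, dim, k + 2, padding=((k + 2) // 2) * d, groups=dim, dilation=d),
            nn.Conv2d(dim, dim, 1),
        )
        # Gate X_i: a plain depth-wise conv over the same input, used to
        # recalibrate LKA_i and damp blocking artifacts from the dilated conv.
        self.gate = nn.Conv2d(dim, dim, k, padding=k // 2, groups=dim)

    def forward(self, a):
        return self.lka(a) * self.gate(a)  # gated attention map MLKA_i

x = torch.randn(1, 32, 16, 16)
print(GatedLKABranch(32)(x).shape)  # torch.Size([1, 32, 16, 16])
```

Multiplying the LKA output by a depth-wise conv of the same input lets the network suppress artifact-prone responses locally instead of passing the dilated-conv output through unchanged.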
Advantages of MLKA:
- More comprehensive long-range dependency learning: through the multi-scale mechanism, MLKA learns long-range correlations at several scales, which helps recover image details more faithfully.
- No blocking artifacts: the gating mechanism effectively suppresses the blocking artifacts introduced by dilated convolutions, preserving image smoothness.
- Computationally efficient: large-kernel decomposition and gating keep the cost low (see the parameter-count sketch below).
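To make the efficiency claim concrete, here is a rough parameter comparison between a dense large-kernel convolution and its decomposed form. The 21x21 reference kernel and channel count are illustrative choices following the LKA decomposition rule, not numbers reported in this paper:

```python
import torch.nn as nn

dim = 120  # illustrative channel count

# Dense 21x21 convolution.
dense = nn.Conv2d(dim, dim, 21, padding=10)

# LKA-style decomposition with a comparable receptive field:
# 5x5 depth-wise conv + 7x7 depth-wise conv (dilation 3) + 1x1 point-wise conv.
decomposed = nn.Sequential(
    nn.Conv2d(dim, dim, 5, padding=2, groups=dim),
    nn.Conv2d(dim, dim, 7, padding=9, dilation=3, groups=dim),
    nn.Conv2d(dim, dim, 1),
)

count = lambda m: sum(p.numel() for p in m.parameters())
print(f"dense:      {count(dense):,} params")       # ~6.35M
print(f"decomposed: {count(decomposed):,} params")  # ~23.6K
```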
Multi-scale Large Kernel Attention structure diagram:
2. Code Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm(nn.Module):
    """LayerNorm supporting channels_last (B, H, W, C) and channels_first (B, C, H, W) inputs."""

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            # Normalize over the channel dimension by hand.
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


class MLKA(nn.Module):
    """Multi-scale Large Kernel Attention. `n_feats` must be divisible by 3."""

    def __init__(self, n_feats, k=2, squeeze_factor=15):
        super().__init__()
        i_feats = 2 * n_feats

        self.norm = LayerNorm(n_feats, data_format='channels_first')
        self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)

        # Three LKA branches: each decomposes a large kernel into a depth-wise conv,
        # a depth-wise dilated conv, and a 1x1 point-wise conv.
        self.LKA7 = nn.Sequential(
            nn.Conv2d(n_feats // 3, n_feats // 3, 7, 1, 7 // 2, groups=n_feats // 3),
            nn.Conv2d(n_feats // 3, n_feats // 3, 9, stride=1, padding=(9 // 2) * 4, groups=n_feats // 3, dilation=4),
            nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))
        self.LKA5 = nn.Sequential(
            nn.Conv2d(n_feats // 3, n_feats // 3, 5, 1, 5 // 2, groups=n_feats // 3),
            nn.Conv2d(n_feats // 3, n_feats // 3, 7, stride=1, padding=(7 // 2) * 3, groups=n_feats // 3, dilation=3),
            nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))
        self.LKA3 = nn.Sequential(
            nn.Conv2d(n_feats // 3, n_feats // 3, 3, 1, 1, groups=n_feats // 3),
            nn.Conv2d(n_feats // 3, n_feats // 3, 5, stride=1, padding=(5 // 2) * 2, groups=n_feats // 3, dilation=2),
            nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))

        # Gating branches: plain depth-wise convs that recalibrate each LKA output.
        self.X3 = nn.Conv2d(n_feats // 3, n_feats // 3, 3, 1, 1, groups=n_feats // 3)
        self.X5 = nn.Conv2d(n_feats // 3, n_feats // 3, 5, 1, 5 // 2, groups=n_feats // 3)
        self.X7 = nn.Conv2d(n_feats // 3, n_feats // 3, 7, 1, 7 // 2, groups=n_feats // 3)

        self.proj_first = nn.Sequential(nn.Conv2d(n_feats, i_feats, 1, 1, 0))
        self.proj_last = nn.Sequential(nn.Conv2d(n_feats, n_feats, 1, 1, 0))

    def forward(self, x, pre_attn=None, RAA=None):
        shortcut = x.clone()
        x = self.norm(x)
        x = self.proj_first(x)
        # Split into an attention path `a` and a value path `x`.
        a, x = torch.chunk(x, 2, dim=1)
        # Split the attention path into three equal groups, one per scale.
        a_1, a_2, a_3 = torch.chunk(a, 3, dim=1)
        # Gated multi-scale attention: MLKA_i = LKA_i(a_i) * X_i(a_i), then concatenate.
        a = torch.cat([self.LKA3(a_1) * self.X3(a_1),
                       self.LKA5(a_2) * self.X5(a_2),
                       self.LKA7(a_3) * self.X7(a_3)], dim=1)
        x = self.proj_last(x * a) * self.scale + shortcut
        return x


if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to CPU if no GPU
    x = torch.randn(4, 360, 64, 64).to(device)
    model = MLKA(360).to(device)
    out = model(x)
    print(out.shape)  # torch.Size([4, 360, 64, 64])
```
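Note that `n_feats` must be divisible by 3 (the example uses 360 = 3 x 120), since the attention path is chunked into three equal groups for the 3/5/7 branches; `pre_attn`, `RAA`, `k`, and `squeeze_factor` are unused arguments carried over in this snippet's interface.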