当前位置: 首页 > ai >正文

(即插即用模块-Attention部分) 六十三、(2024 CVPR) MLKA 多尺度大核注意力

在这里插入图片描述

文章目录

  • 1、Multi-scale Large Kernel Attention
  • 2、代码实现

paper:MULTI-SCALE ATTENTION NETWORK FOR SINGLE IMAGE SUPER-RESOLUTION

Code:https://github.com/icandle/MAN


1、Multi-scale Large Kernel Attention

为了解决如何有效地建立不同区域之间的长距离相关性,并避免由于大卷积核带来的“块效应”问题。这篇论文在 LKA 的基础上提出了一种 多尺度大核注意力(Multi-scale Large Kernel Attention),MLKA 的设计动机是为了解决图像超分辨率任务中,MLKA 结合了 大卷积核分解 和 多尺度机制 来实现这一目标。

MLKA 的实现过程:

  1. 输入特征图 X: 输入特征图 X 被分解成多个组,每个组包含相同数量的通道。
  2. LKA 模块: 对每个组应用 LKA 模块,生成不同尺度上的注意力图 LKAi。
  3. 门控模块: 为了避免扩张卷积带来的“块效应”,对每个组生成的注意力图进行动态重校准。这样可以更好地保留局部纹理信息。通过对每个 LKAi 应用门控模块,生成门控注意力图 MLKAi。
  4. 聚合: 将所有 MLKAi 聚合,得到最终的注意力图。

MLKA 的优势:

  • 更全面的长距离相关性学习: 通过多尺度机制,MLKA 可以学习不同尺度上的长距离相关性,从而更好地恢复图像细节。
  • 避免“块效应”: 通过门控机制,MLKA 可以有效地避免扩张卷积带来的“块效应”,从而更好地保留图像的平滑性。
  • 计算效率高: MLKA 通过大卷积核分解和门控机制,实现了计算效率的提升。

Multi-scale Large Kernel Attention 结构图:
在这里插入图片描述


2、代码实现

import math
import torch
import torch.nn as nn
import torch.nn.functional as Fclass LayerNorm(nn.Module):def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):super().__init__()self.weight = nn.Parameter(torch.ones(normalized_shape))self.bias = nn.Parameter(torch.zeros(normalized_shape))self.eps = epsself.data_format = data_formatif self.data_format not in ["channels_last", "channels_first"]:raise NotImplementedErrorself.normalized_shape = (normalized_shape,)def forward(self, x):if self.data_format == "channels_last":return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)elif self.data_format == "channels_first":u = x.mean(1, keepdim=True)s = (x - u).pow(2).mean(1, keepdim=True)x = (x - u) / torch.sqrt(s + self.eps)x = self.weight[:, None, None] * x + self.bias[:, None, None]return xclass MLKA(nn.Module):def __init__(self, n_feats, k=2, squeeze_factor=15):super().__init__()i_feats = 2 * n_featsself.norm = LayerNorm(n_feats, data_format='channels_first')self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)# Multiscale Large Kernel Attentionself.LKA7 = nn.Sequential(nn.Conv2d(n_feats // 3, n_feats // 3, 7, 1, 7 // 2, groups=n_feats // 3),nn.Conv2d(n_feats // 3, n_feats // 3, 9, stride=1, padding=(9 // 2) * 4, groups=n_feats // 3, dilation=4),nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))self.LKA5 = nn.Sequential(nn.Conv2d(n_feats // 3, n_feats // 3, 5, 1, 5 // 2, groups=n_feats // 3),nn.Conv2d(n_feats // 3, n_feats // 3, 7, stride=1, padding=(7 // 2) * 3, groups=n_feats // 3, dilation=3),nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))self.LKA3 = nn.Sequential(nn.Conv2d(n_feats // 3, n_feats // 3, 3, 1, 1, groups=n_feats // 3),nn.Conv2d(n_feats // 3, n_feats // 3, 5, stride=1, padding=(5 // 2) * 2, groups=n_feats // 3, dilation=2),nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))self.X3 = nn.Conv2d(n_feats // 3, n_feats // 3, 3, 1, 1, groups=n_feats // 3)self.X5 = nn.Conv2d(n_feats // 3, n_feats // 3, 5, 1, 5 // 2, groups=n_feats // 3)self.X7 = nn.Conv2d(n_feats // 3, n_feats // 3, 7, 1, 7 // 2, groups=n_feats // 3)self.proj_first = nn.Sequential(nn.Conv2d(n_feats, i_feats, 1, 1, 0))self.proj_last = nn.Sequential(nn.Conv2d(n_feats, n_feats, 1, 1, 0))def forward(self, x, pre_attn=None, RAA=None):shortcut = x.clone()x = self.norm(x)x = self.proj_first(x)a, x = torch.chunk(x, 2, dim=1)a_1, a_2, a_3 = torch.chunk(a, 3, dim=1)a = torch.cat([self.LKA3(a_1) * self.X3(a_1), self.LKA5(a_2) * self.X5(a_2), self.LKA7(a_3) * self.X7(a_3)],dim=1)x = self.proj_last(x * a) * self.scale + shortcutreturn xif __name__ == '__main__':x = torch.randn(4, 360, 64, 64).cuda()model = MLKA(360).cuda()out = model(x)print(out.shape)
http://www.xdnf.cn/news/3595.html

相关文章:

  • 计算机视觉与深度学习 | 什么是图像金字塔?
  • 如何用CSS实现HTML元素的旋转效果:从基础到高阶应用
  • SQL ROUND() 函数详解
  • MySQL基础关键_006_DQL(五)
  • 数智图书馆的信息组织革命:AI变革下的新秩序
  • Spring 事务的底层原理常见陷阱
  • Fabrice Bellard(个人网站:‌bellard.org‌)介绍
  • ad 多通道设计中出现的相关问题
  • AWS上构建基于自然语言和LINDO API的线性规划与非线性规划的优化计算系统
  • MCP 探索:MCP 集成的相关网站 Smithery、PulseMCP 等
  • Java面试趣事:从死循环到分段锁
  • Lua 基础 API与 辅助库函数 中关于创建的方法用法
  • 基于STM32的智能摇头风扇设计(WIFI+语音控制)
  • CGAL:最小包围圆
  • 共铸价值:RWA 联合曲线价值模型,撬动现实资产生态
  • 基于机器学习的心脏病数据分析与可视化(百度智能云千帆AI+DeepSeek人工智能+机器学习)健康预测、风险评估与数据可视化 健康管理平台 数据分析与处理
  • k8s 探针
  • 基于ArduinoIDE的任意型号单片机 + GPS北斗BDS卫星定位
  • 基于「骑手外卖系统」串联7大设计原则
  • 【Hot 100】 146. LRU 缓存
  • Three.js在vue中的使用(二)-加载、控制
  • 【ICMP协议深度解析】从网络诊断到安全实践
  • Mysql常用语句汇总
  • centos7.0无法安装php8.2/8.3
  • ROS2学习笔记|创建工作空间并打印文件内容
  • 视频编解码学习二之颜色科学
  • UDP / TCP 协议
  • 使用DeepSeek协助恢复历史数据
  • GoFrame 奉孝学习笔记
  • ElasticSearch深入解析(十):字段膨胀(Mapping 爆炸)问题的解决思路