当前位置: 首页 > ds >正文

YOLOV11改进策略【最新注意力机制】CVPR2025局部区域注意力机制LRSA-增强局部区域特征之间的交互

1.1网络结构

1.2 添加过程

1.2.1 核心代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrangedef patch_divide(x, step, ps):"""Crop image into patches.Args:x (Tensor): Input feature map of shape(b, c, h, w).step (int): Divide step.ps (int): Patch size.Returns:crop_x (Tensor): Cropped patches.nh (int): Number of patches along the horizontal direction.nw (int): Number of patches along the vertical direction."""b, c, h, w = x.size()if h == ps and w == ps:step = pscrop_x = []nh = 0for i in range(0, h + step - ps, step):top = idown = i + psif down > h:top = h - psdown = hnh += 1for j in range(0, w + step - ps, step):left = jright = j + psif right > w:left = w - psright = wcrop_x.append(x[:, :, top:down, left:right])nw = len(crop_x) // nhcrop_x = torch.stack(crop_x, dim=0)  # (n, b, c, ps, ps)crop_x = crop_x.permute(1, 0, 2, 3, 4).contiguous()  # (b, n, c, ps, ps)return crop_x, nh, nwdef patch_reverse(crop_x, x, step, ps):"""Reverse patches into image.Args:crop_x (Tensor): Cropped patches.x (Tensor): Feature map of shape(b, c, h, w).step (int): Divide step.ps (int): Patch size.Returns:ouput (Tensor): Reversed image."""b, c, h, w = x.size()output = torch.zeros_like(x)index = 0for i in range(0, h + step - ps, step):top = idown = i + psif down > h:top = h - psdown = hfor j in range(0, w + step - ps, step):left = jright = j + psif right > w:left = w - psright = woutput[:, :, top:down, left:right] += crop_x[:, index]index += 1for i in range(step, h + step - ps, step):top = idown = i + ps - stepif top + ps > h:top = h - psoutput[:, :, top:down, :] /= 2for j in range(step, w + step - ps, step):left = jright = j + ps - stepif left + ps > w:left = w - psoutput[:, :, :, left:right] /= 2return outputclass PreNorm(nn.Module):"""Normalization layer.Args:dim (int): Base channels.fn (Module): Module after normalization."""def __init__(self, dim, fn):super().__init__()self.norm = nn.LayerNorm(dim)self.fn = fndef forward(self, x, **kwargs):return self.fn(self.norm(x), **kwargs)class Attention(nn.Module):"""Attention module.Args:dim (int): Base channels.heads (int): Head numbers.qk_dim (int): Channels of query and key."""def __init__(self, dim, heads, qk_dim):super().__init__()self.heads = headsself.dim = dimself.qk_dim = qk_dimself.scale = qk_dim ** -0.5self.to_q = nn.Linear(dim, qk_dim, bias=False)self.to_k = nn.Linear(dim, qk_dim, bias=False)self.to_v = nn.Linear(dim, dim, bias=False)self.proj = nn.Linear(dim, dim, bias=False)def forward(self, x):q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), (q, k, v))out = F.scaled_dot_product_attention(q, k, v)  # scaled_dot_product_attention 需要PyTorch2.0之后版本out = rearrange(out, 'b h n d -> b n (h d)')return self.proj(out)class dwconv(nn.Module):def __init__(self, hidden_features, kernel_size=5):super(dwconv, self).__init__()self.depthwise_conv = nn.Sequential(nn.Conv2d(hidden_features, hidden_features, kernel_size=kernel_size, stride=1,padding=(kernel_size - 1) // 2, dilation=1,groups=hidden_features), nn.GELU())self.hidden_features = hidden_featuresdef forward(self, x, x_size):x = x.transpose(1, 2).view(x.shape[0], self.hidden_features, x_size[0], x_size[1]).contiguous()  # b Ph*Pw cx = self.depthwise_conv(x)x = x.flatten(2).transpose(1, 2).contiguous()return xclass ConvFFN(nn.Module):def __init__(self, in_features, hidden_features=None, out_features=None, kernel_size=5, act_layer=nn.GELU):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.fc1 = nn.Linear(in_features, hidden_features)self.act = act_layer()self.dwconv = dwconv(hidden_features=hidden_features, kernel_size=kernel_size)self.fc2 = nn.Linear(hidden_features, out_features)def forward(self, x, x_size):x = self.fc1(x)x = self.act(x)x = x + self.dwconv(x, x_size)x = self.fc2(x)return xclass LRSA(nn.Module):"""Attention module.Args:dim (int): Base channels.num (int): Number of blocks.qk_dim (int): Channels of query and key in Attention.mlp_dim (int): Channels of hidden mlp in Mlp.heads (int): Head numbers of Attention."""def __init__(self, dim, qk_dim=36, mlp_dim=96, heads=4):super().__init__()self.layer = nn.ModuleList([PreNorm(dim, Attention(dim, heads, qk_dim)),PreNorm(dim, ConvFFN(dim, mlp_dim))])def forward(self, x, ps=16):step = ps - 2crop_x, nh, nw = patch_divide(x, step, ps)  # (b, n, c, ps, ps)b, n, c, ph, pw = crop_x.shapecrop_x = rearrange(crop_x, 'b n c h w -> (b n) (h w) c')attn, ff = self.layercrop_x = attn(crop_x) + crop_xcrop_x = rearrange(crop_x, '(b n) (h w) c  -> b n c h w', n=n, w=pw)x = patch_reverse(crop_x, x, step, ps)_, _, h, w = x.shapex = rearrange(x, 'b c h w-> b (h w) c')x = ff(x, x_size=(h, w)) + xx = rearrange(x, 'b (h w) c->b c h w', h=h)return x

1.2.2 添加过程

第一步:在nn文件夹下面创建一个新的文件夹,名字叫做Addmodules, 然后再这个文件夹下面创建LRSA.py文件,然后将LRSA的核心代码放入其中。

第二步:继续在Addmodules文件夹下面创建__init__.py文件,输入以下代码:


from .LRSA import LRSA
__all__ = ("LRSA",)

第三步在task.py文件中导入LRSA

from .Addmodules import LRSA
 elif m is LRSA:c1 = ch[f]args = [c1, *args]

第四步:修改yaml文件

# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license# Ultralytics YOLO11 object detection model with P3/8 - P5/32 outputs
# Model docs: https://docs.ultralytics.com/models/yolo11
# Task docs: https://docs.ultralytics.com/tasks/detect# Parameters
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11n.yaml with scale 'n'# [depth, width, max_channels]n: [0.50, 0.25, 1024] # summary: 181 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPss: [0.50, 0.50, 1024] # summary: 181 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPsm: [0.50, 1.00, 512] # summary: 231 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPsl: [1.00, 1.00, 512] # summary: 357 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPsx: [1.00, 1.50, 512] # summary: 357 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs# YOLO11n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4- [-1, 2, C3k2, [256, False, 0.25]]- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8- [-1, 2, C3k2, [512, False, 0.25]]- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16- [-1, 2, C3k2, [512, True]]- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32- [-1, 2, C3k2, [1024, True]]- [-1, 1, SPPF, [1024, 5]] # 9- [-1, 2, C2PSA, [1024]] # 10- [-1, 3, LRSA, [1024]]
# YOLO11n head
head:- [-1, 1, nn.Upsample, [None, 2, "nearest"]]- [[-1, 6], 1, Concat, [1]] # cat backbone P4- [-1, 2, C3k2, [512, False]] # 13- [-1, 1, nn.Upsample, [None, 2, "nearest"]]- [[-1, 4], 1, Concat, [1]] # cat backbone P3- [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 13], 1, Concat, [1]] # cat head P4- [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 10], 1, Concat, [1]] # cat head P5- [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)- [[16, 19, 22], 1, Detect, [nc]] # Detect(P3, P4, P5)

请多多关注!!!

http://www.xdnf.cn/news/8439.html

相关文章:

  • 使用DDR4控制器实现多通道数据读写(十三)
  • DAO模式
  • DEBUG设置为False 时,django默认的后台样式等静态文件丢失的问题
  • 新能源汽车滑行阻力参数计算全解析:从理论推导到MATLAB工具实现
  • macOS 安装 PostgreSQL
  • 基于大模型的股骨干骨折全周期预测与诊疗方案研究报告
  • 可视化大屏全屏后重载echarts图表
  • JUC并发编程1
  • MyBatis 笔记:parameterType、resultType 与 resultMap 的区别详解
  • Android 网络全栈攻略(四)—— 从 OkHttp 拦截器来看 HTTP 协议一
  • 146. LRU Cache
  • Anthropic公司近日发布了两款新一代大型语言模型Claude Opus 4与Claude Sonnet 4
  • 矩阵:线性代数在AI大模型中的核心支柱
  • 深入解析MySQL中的HAVING关键字:从入门到实战
  • Docker 与 Kubernetes 部署 RabbitMQ 集群(二)
  • C++ 忘掉std::cout吧,fmt和spdlog的结合
  • 达梦数据库-报错-01-[-3205]:全文索引词库加载出错
  • paddle 打包代码 ocr
  • 国产高云FPGA实现MIPI视频解码+图像缩放,基于OV5647摄像头,提供Gowin工程源码和技术支持
  • 04-jenkins学习之旅-java后端项目部署实践
  • 攻略生成模块
  • python邮件地址检验 2024年信息素养大赛复赛/决赛真题 小学组/初中组 python编程挑战赛 真题详细解析
  • C++---vector模拟实现
  • 黑马点评-实现安全秒杀优惠券(使并发一人一单,防止并发超卖)
  • Java桌面应用开发详解:自制截图工具从设计到打包的全流程【附源码与演示】
  • LVS + Keepalived + Nginx 高可用负载均衡系统实验
  • 详解Mysql的 Binlog、UndoLog 和 RedoLog
  • 「金融证券行业」 如何搭建自己的研发智能管理体系?
  • Linux 操作文本文件列数据的常用命令
  • @Column 注解属性详解