当前位置: 首页 > news >正文

空间注意力机制

知识点:

空间注意力机制 spatial attention SA;

SA 中平均池化和最大池化的操作;

torch.max;


参考博客:通俗易懂理解通道注意力机制(CAM)与空间注意力机制(SAM)-CSDN博客

 


空间注意力机制代码

import torch
import torch.nn as nn
import torch.nn.functional as Fclass SpatialAttention(nn.Module):def __init__(self,kernel_size=7):"""初始化空间注意力模块Args:kernel_size (int): 卷积核大小,通常为7x7"""super().__init__()# 确保kernel_size是奇数,以便paddingassert kernel_size % 2 ==1padding = kernel_size // 2self.sigmoid = nn.Sigmoid()# 定义7x7卷积层,输入通道为2(平均池化和最大池化的结果),输出通道为1self.conv = nn.Conv2d(in_channels=2,  # 输入通道数为2(平均池化和最大池化的结果)out_channels=1, # 输出通道数为1(生成空间注意力图)kernel_size=kernel_size,  # 卷积核大小,通常为7x7padding=padding,   # 填充,保持特征图大小不变bias=False # 不使用偏置)def forward(self, x):"""前向传播Args:x (torch.Tensor): 输入特征图 [B, C, H, W]Returns:torch.Tensor: 经过空间注意力加权后的特征图"""# 沿着通道维度进行平均池化和最大池化avg_pool = torch.mean(x, dim=1, keepdim=True) # F_avg^s [B,1,H,W]# 注意这里返回值是两个,最大值和索引,要用两个参数接max_pool,_ = torch.max(x, dim=1, keepdim=True)  # F_max^s [B,1,H,W]# 拼接平均池化和最大池化的结果pooled_features = torch.cat((avg_pool, max_pool), dim=1)  # [B,2,H,W]# 通过 7 * 7 卷积层处理spatial_attention = self.conv(pooled_features)# sigmoid激活spatial_attention = self.sigmoid(spatial_attention)return x * spatial_attentionif __name__ == '__main__':# 创建测试数据batch_size=2channels=3height=64width = 64x = torch.randn(batch_size, channels, height, width)sa=SpatialAttention(kernel_size=7)outputs=sa(x)print(f"input shape:{x.shape}")print(f"output shape:{outputs.shape}")

沿通道维度的平均池化

avg_pool = torch.mean(x, dim=1, keepdim=True) # F_avg^s [B,1,H,W]

沿通道维度的最大池化

 max_values, _ = torch.max(x, dim=1, keepdim=True)  # F_max^s [B,1,H,W]

注意这里返回是两个值,最大值索引也返回了,必须要用两个参数接!!!

vs 通道注意力机制中的池化操作

 

http://www.xdnf.cn/news/962281.html

相关文章:

  • uniapp开发小程序vendor.js 过大
  • 使用java实现蒙特卡洛模拟风险预测功能
  • AI一周事件(2025年6月3日-6月9日)
  • WHAT - 组件库单入口打包和多入口打包
  • “液态玻璃”难解苹果AI焦虑:WWDC25背后的信任危机
  • 自动化三维扫描检测赋能汽车铸造件高效检测
  • 笔记 操作系统复习
  • 供应链管理-物流:自动驾驶分为几个级别/L0无自动化/L1驾驶辅助/L2部分自动化/L3有条件自动化/L4高度自动化/L5完全自动化
  • 云原生核心技术 (7/12): K8s 核心概念白话解读(上):Pod 和 Deployment 究竟是什么?
  • SDC命令详解:使用uniquify命令进行唯一化
  • 菲尔斯特传感器,超声波风速风向传感器助力绿色能源发展
  • idea中黄色感叹号打开
  • RPC调用三 使用代理进行服务自动注册
  • CppCon 2015 学习:RapidCheck Property based testing for C++
  • 计算机基础(一):ASCll、GB2312、GBK、Unicode、UTF-32、UTF-16、UTF-8深度解析
  • 记录chrome浏览器的一个bug
  • 零基础入门 线性代数
  • 上位机开发过程中的设计模式体会(2):观察者模式和Qt信号槽机制
  • 经典的多位gpio初始化操作
  • 基于FPGA的PID算法学习———实现PI比例控制算法
  • React Native 基础语法与核心组件:深入指南
  • 篇章三 论坛系统——环境搭建
  • 如何将数据从 iPhone 传输到笔记本电脑
  • ACM70V-701-2PL-TL00
  • CPP基础(2)
  • Linux 删除登录痕迹
  • rapidocr v3.1.0发布
  • 什么样的登录方式才是最安全的?
  • 高频交易技术:订单簿分析与低延迟架构——从Level 2数据挖掘到FPGA硬件加速的全链路解决方案
  • Numpy7——数学2(矩阵基础,线性方程基础)