当前位置: 首页 > ops >正文

Day 47 注意力热图可视化

@浙大疏锦行

这个注意力热图是通过钩子机制 register_forward_hook 捕获最后一个卷积层(conv3)的输出特征图。

  1. 通道权重计算:对特征图的每个通道进行全局平均池化,得到通道重要性权重。
  2. 热力图生成:将高权重通道的特征图缩放至原始图像尺寸,与原图叠加显示。

热力图(红色表示高关注,蓝色表示低关注)半透明覆盖在原图上。主要从以下方面理解:

  • 高关注区域(红色):模型认为对分类最重要的区域。
    例如:
    • 在识别 “狗” 时,热力图可能聚焦狗的面部、身体轮廓或特征性纹理。
    • 若热力图错误聚焦背景(如红色区域在无关物体上),可能表示模型过拟合或训练不足。

多通道对比

  • 不同通道关注不同特征
    例如:
    • 通道 1 可能关注整体轮廓,通道 2 关注纹理细节,通道 3 关注颜色分布。
    • 结合多个通道的热力图,可全面理解模型的决策逻辑。

可以帮助解释

  • 检查模型是否关注正确区域(如识别狗时,是否聚焦狗而非背景)。
  • 发现数据标注问题(如标签错误、图像噪声)。
  • 向非技术人员解释模型决策依据(如 “模型认为这是狗,因为关注了眼睛和嘴巴”)。
# 可视化空间注意力热力图(显示模型关注的图像区域)
def visualize_attention_map(model, test_loader, device, class_names, num_samples=3):"""可视化模型的注意力热力图,展示模型关注的图像区域"""model.eval()  # 设置为评估模式with torch.no_grad():for i, (images, labels) in enumerate(test_loader):if i >= num_samples:  # 只可视化前几个样本breakimages, labels = images.to(device), labels.to(device)# 创建一个钩子,捕获中间特征图activation_maps = []def hook(module, input, output):activation_maps.append(output.cpu())# 为最后一个卷积层注册钩子(获取特征图)hook_handle = model.conv3.register_forward_hook(hook)# 前向传播,触发钩子outputs = model(images)# 移除钩子hook_handle.remove()# 获取预测结果_, predicted = torch.max(outputs, 1)# 获取原始图像img = images[0].cpu().permute(1, 2, 0).numpy()# 反标准化处理img = img * np.array([0.2023, 0.1994, 0.2010]).reshape(1, 1, 3) + np.array([0.4914, 0.4822, 0.4465]).reshape(1, 1, 3)img = np.clip(img, 0, 1)# 获取激活图(最后一个卷积层的输出)feature_map = activation_maps[0][0].cpu()  # 取第一个样本# 计算通道注意力权重(使用SE模块的全局平均池化)channel_weights = torch.mean(feature_map, dim=(1, 2))  # [C]# 按权重对通道排序sorted_indices = torch.argsort(channel_weights, descending=True)# 创建子图fig, axes = plt.subplots(1, 4, figsize=(16, 4))# 显示原始图像axes[0].imshow(img)axes[0].set_title(f'原始图像\n真实: {class_names[labels[0]]}\n预测: {class_names[predicted[0]]}')axes[0].axis('off')# 显示前3个最活跃通道的热力图for j in range(3):channel_idx = sorted_indices[j]# 获取对应通道的特征图channel_map = feature_map[channel_idx].numpy()# 归一化到[0,1]channel_map = (channel_map - channel_map.min()) / (channel_map.max() - channel_map.min() + 1e-8)# 调整热力图大小以匹配原始图像from scipy.ndimage import zoomheatmap = zoom(channel_map, (32/feature_map.shape[1], 32/feature_map.shape[2]))# 显示热力图axes[j+1].imshow(img)axes[j+1].imshow(heatmap, alpha=0.5, cmap='jet')axes[j+1].set_title(f'注意力热力图 - 通道 {channel_idx}')axes[j+1].axis('off')plt.tight_layout()plt.show()# 调用可视化函数
visualize_attention_map(model, test_loader, device, class_names, num_samples=3)

http://www.xdnf.cn/news/20170.html

相关文章:

  • 工作后的总结和反思4
  • SQL 入门指南:排序与分页查询(ORDER BY 多字段排序、LIMIT 分页实战)
  • 使用Shell脚本实现Linux系统资源监控邮件告警
  • 永磁同步电机 FOC 控制中 d、q 轴杂谈与角度偏移影响
  • 使用Ansible自动化部署Hadoop集群(含源码)--环境准备
  • 【Android】ViewPager2结合Fragment实现多页面滑动切换
  • 百度竞价推广:搜索竞价信息流推广代运营
  • ElementUI之Upload 上传的使用
  • C++语法之--多态
  • 了解Python
  • Ubuntu:Git SSH密钥配置的完整流程
  • 捷多邦揭秘超厚铜板:从制造工艺到设计关键环节​
  • 让字符串变成回文串的最少插入次数-二维dp
  • 单元测试详解
  • 基于树莓派与Jetson Nano集群的实验边缘设备上视觉语言模型(VLMs)的性能评估与实践探索
  • 【c++进阶系列】:万字详解AVL树(附源码实现)
  • ubuntu 系統使用過程中黑屏問題分析
  • 前端上传切片优化以及实现
  • 基于LLM开发Agent应用开发问题总结
  • equals 定义不一致导致list contains错误
  • SQL面试题及详细答案150道(81-100) --- 子查询篇
  • webrtc弱网-LossBasedBandwidthEstimation类源码分析与算法原理
  • 【Proteus仿真】定时器控制系列仿真——秒表计数/数码管显示时间
  • 【ComfyUI】混合 ControlNet 多模型组合控制生成
  • ANSYS HFSS边界条件的认识
  • 【LeetCode热题100道笔记】二叉树中的最大路径和
  • 9.FusionAccess桌面云
  • Spring的事件监听机制(一)
  • 03.缓存池
  • 【数学建模】质量消光系数在烟幕遮蔽效能建模中的核心作用