当前位置: 首页 > java >正文

特征图可视化代码

  • 进行特征图可视化的时候,修改模型的forward函数来进行可视化十分麻烦,还需要想办法把特征图传出来,在模型层层调用的时候更加麻烦,要修改多个无关的嵌套,还容易引起bug。这里提供了一个简单的范式,只需要一个vis.py文件(可从train.py或者test.py修改而来),无需修改模型的定义文件,即可实现特征图的可视化。
  • 该做法的核心思想是两点,第一点是利用vis.py里面的全局变量来存储特征图以及网络层数等,第二点是直接在vis.py里面重写需要可视化特征图的module的forward函数,以用最小的改动将特征图传递出来。
  • 这段代码还提供了利用sns.heatmap可视化特征图的例子,整体代码如下:
# vis talor mod
import argparse
import os
import math
from functools import partialimport yaml
import torch
from torch.utils.data import DataLoader
from tqdm import tqdmimport datasets
import models
import utils
from torchvision import transforms
from PIL import Image
import random
import numpy as npimport matplotlib.pyplot as plt
import seaborn as sns
from models.models_meta import mGAttn
from einops import rearrange as rearrange@torch.no_grad()
def vis_mod(mod_dict, path, name):for i in [3, 13]:#[3,5,19]:scale = mod_dict[f'layer_{i}_scale'] offset = mod_dict[f'layer_{i}_offset']name_scale_i = name+f'scale_{i}_avg.png'vis_feature(scale, path, name_scale_i)vis_feature_each_channel(scale, path, name_scale_i)name_offset_i = name+f'offset_{i}_avg.png'vis_feature(offset, path, name_offset_i)   vis_feature_each_channel(offset, path, name_offset_i)   return def vis_feature(feature, path, name):feature = feature[0, ...]# lower_percentile = 0.1# upper_percentile = 0.9# for i in range(feature.shape[0]):#     feature_i = feature[i].view(-1)#     lower_bound = torch.quantile(feature_i, lower_percentile)#     upper_bound = torch.quantile(feature_i, upper_percentile)#     feature[i,...] = torch.clamp(feature[i,...], lower_bound, upper_bound)# plt.figure()plt.figure(figsize=(1.58, 1.58))ax = sns.heatmap(torch.mean(feature, dim=0).cpu().detach().numpy(), cbar=False, annot=False, xticklabels=[], yticklabels=[], cmap='rainbow')ax.tick_params(axis='both', which='both', length=0)plt.tight_layout()plt.savefig(os.path.join(path, name))plt.close()for i in range(feature.size(0)//32):plt.figure(figsize=(1.58, 1.58))ax = sns.heatmap(torch.mean(feature[i*32:(i+1)*32, :, :], dim=0).cpu().detach().numpy(), cbar=False, annot=False, xticklabels=[], yticklabels=[], cmap='rainbow')ax.tick_params(axis='both', which='both', length=0)plt.tight_layout()plt.savefig(os.path.join(path, name.replace('avg', f'avg_{i}')))plt.close()def vis_feature_each_channel(feature, path, name):feature = feature[0, ...]# lower_percentile = 0.1# upper_percentile = 0.9# for i in range(feature.shape[0]):#     feature_i = feature[i].view(-1)#     lower_bound = torch.quantile(feature_i, lower_percentile)#     upper_bound = torch.quantile(feature_i, upper_percentile)#     feature[i,...] = torch.clamp(feature[i,...], lower_bound, upper_bound)for i in range(feature.shape[0]):# plt.figure()plt.figure(figsize=(1.58, 1.58))ax = sns.heatmap(feature[i].cpu().detach().numpy(), cbar=False, annot=False, xticklabels=[], yticklabels=[], cmap='rainbow')ax.tick_params(axis='both', which='both', length=0)plt.tight_layout()plt.savefig(os.path.join(path, name.replace('avg', f'channel_{i}')))plt.close()global_feature_maps = {}
def modify_forward_for_mGAttn(module):if isinstance(module, mGAttn):# original_forward = module.forwarddef modified_forward(self, x):"""x: b * c * h * w"""# 这里省略了模型原有的一些forward过程curr_layer = global_feature_maps['curr_layer']if curr_layer in [3, 13]:#[1,5,19]:B, h, Ch, N = offset.shapeglobal_feature_maps[f'layer_{curr_layer}'] = feature.view(B, h, He, We)global_feature_maps['curr_layer'] = curr_layer+1# 这里省略了模型原有的一些forward过程return outmodule.forward =  modified_forward.__get__(module)for child_module in module.children():modify_forward_for_mGAttn(child_module)if __name__ == '__main__':# 这里省略了一些模型的定义过程# modify forward for mGAttnmodify_forward_for_mGAttn(model)# 接着按自己的方式直接调用模型即可myeval(model)
http://www.xdnf.cn/news/10251.html

相关文章:

  • 数据库核心技术深度剖析:事务、索引、锁与SQL优化实战指南(第四节)----从行级锁到死锁处理的系统梳理
  • WIN11+CUDA11.8+VS2019配置BundleFusion
  • Linux之MySQL安装篇
  • Redis主从复制详解
  • 扫一扫的时候会经历哪些事
  • 华为OD机试真题——模拟消息队列(2025A卷:100分)Java/python/JavaScript/C++/C语言/GO六种最佳实现
  • 哪些工作最容易被AI取代?
  • C++基础算法————深度优先搜索(DFS)
  • 【速通RAG实战:进阶】17、AI视频打点全攻略:从技术实现到媒体工作流提效的实战指南
  • 嵌入式(C语言篇)Day13
  • Go语言事件总线EventBus本地事件总线系统的完整实现框架
  • Angularjs-Hello
  • Java中的引用类型以及区别的特点
  • 复数三角不等式简介及 MATLAB 演示
  • 电脑用户名是中文,conda配置环境报错,该怎么解决
  • SpringBoot网络请求RestTemplate Util工具类
  • Kerberos面试内容整理-会话密钥的协商与使用
  • WIN11+eclipse搭建java开发环境
  • 端午安康(Python)
  • C++深入类与对象
  • 电脑重装或者开机出现错误
  • 【harbor】--基础使用
  • 利用aqs构建一个自己的非公平独占锁
  • 【数据集】全球无缝高分辨率1 km 月均地表温度和气温(2001-2020)
  • 小白的进阶之路系列之八----人工智能从初步到精通pytorch综合运用的讲解第一部分
  • 【C++】 类和对象(上)
  • Matlab数据类型
  • 界面形成能的理解
  • Python简易音乐播放器开发教程
  • day61—DFS—省份数量(LeetCode-547)