当前位置: 首页 > backend >正文

AI笔记 - 模型调试 - 调试方式

模型调试方式

  • 基础信息
  • 打印模型信息
  • 计算参数量和计算量
    • 过滤原则
    • profile方法
    • get_model_complexity_info方法
    • FlopCountAnalysis方法

基础信息

# 打印执行的设备数量:device_count:1
print(f"device_count:{torch.cuda.device_count()}")# 打印当前网络执行的设备信息:device: cuda:0
print(f"device: {next(self.net.parameters()).device}")  # 应该输出: cuda:0

打印模型信息

#操作	    代码示例
#-----------------------------------------------------
#遍历所有模块	for name, module in model.named_modules():
#-----------------------------------------------------
#打印参数详情	module.named_parameters()
#-----------------------------------------------------
#打印缓冲区	module.named_buffers()
#-----------------------------------------------------
#过滤特定层	isinstance(module, nn.Conv2d)
#-----------------------------------------------------
#统计计算量	profile(module, inputs=(input,))
#-----------------------------------------------------import torchvision.models as modelsmodel = models.resnet50(weights=None).cuda()  # 不加载预训练权重以减少下载时间
input = torch.randn(1, 3, 224, 224).cuda()
for name, p in model.named_parameters():
print(f"params name:{name}, shape:{p.shape}, device:{p.device}")
print(f"dtype: {p.dtype}, 是否需要梯度:{p.requires_grad}")#params name:conv1.weight, shape:torch.Size([64, 3, 7, 7]), device:cuda:0
#dtype: torch.float32, 是否需要梯度:True
#params name:bn1.weight, shape:torch.Size([64]), device:cuda:0
#dtype: torch.float32, 是否需要梯度:True
#params name:bn1.bias, shape:torch.Size([64]), device:cuda:0
#dtype: torch.float32, 是否需要梯度:True
...for name, module in model.named_modules():print(f"模块名称:{name}, 模块类型:{type(module).__name__}")# 打印可训练参数(weight/bias)for param_name, param in module.named_parameters():print(f"  - 参数:{param_name} | 形状:{param.shape} | 设备:{param.device} | 需梯度:{param.requires_grad} | 数据类型:{param.dtype}")# 打印缓冲区(如BatchNorm的running_mean)for buffer_name, buffer in module.named_buffers():print(f"  - 缓冲区: {buffer_name} | 形状: {buffer.shape} | 设备: {buffer.device}")# 模块名称:, 模块类型:ResNet
#  - 参数:conv1.weight | 形状:torch.Size([64, 3, 7, 7]) | 设备:cuda:0 | 需梯度:True | 数据类型:torch.float32
#  - 参数:bn1.weight | 形状:torch.Size([64]) | 设备:cuda:0 | 需梯度:True | 数据类型:torch.float32
#  - 参数:bn1.bias | 形状:torch.Size([64]) | 设备:cuda:0 | 需梯度:True | 数据类型:torch.float32......
#  - 缓冲区: bn1.running_mean | 形状: torch.Size([64]) | 设备: cuda:0
#  - 缓冲区: bn1.running_var | 形状: torch.Size([64]) | 设备: cuda:0
#  - 缓冲区: bn1.num_batches_tracked | 形状: torch.Size([]) | 设备: cuda:0
#  - 缓冲区: layer1.0.bn1.running_mean | 形状: torch.Size([64]) | 设备: cuda:0......
# 模块名称:layer1.0, 模块类型:Bottleneck
#  - 参数:conv1.weight | 形状:torch.Size([64, 64, 1, 1]) | 设备:cuda:0 | 需梯度:True | 数据类型:torch.float32
#  - 参数:bn1.weight | 形状:torch.Size([64]) | 设备:cuda:0 | 需梯度:True | 数据类型:torch.float32......
#  - 缓冲区: bn1.running_mean | 形状: torch.Size([64]) | 设备: cuda:0
#  - 缓冲区: bn1.running_var | 形状: torch.Size([64]) | 设备: cuda:0
#  - 缓冲区: bn1.num_batches_tracked | 形状: torch.Size([]) | 设备: cuda:0......
#模块名称:layer1.0.conv1, 模块类型:Conv2d
#  - 参数:weight | 形状:torch.Size([64, 64, 1, 1]) | 设备:cuda:0 | 需梯度:True | 数据类型:torch.float32
#模块名称:layer1.0.bn1, 模块类型:BatchNorm2d
......

计算参数量和计算量

过滤原则

在计算模型计算量(FLOPs)时,过滤掉 BatchNorm2d、Sequential 和 Bottleneck 等非关键层是常见的需求

层类型是否过滤原因
BatchNorm2d✅ 过滤计算量极小(仅逐通道缩放),可忽略
Sequential✅ 过滤容器层(实际计算在子层)
Bottleneck✅ 过滤复合层(计算量已包含在子层中)
Conv2d/Linear❌ 保留核心计算层
ReLU/Pooling⚠️ 可选通常忽略(或单独统计)

profile方法

from thop import profilemodel = models.resnet50(weights=None).cuda()  # 不加载预训练权重以减少下载时间
input = torch.randn(1, 3, 224, 224).cuda()
flops, params = profile(model, inputs=(input,))
print(f"FLOPs: {flops / 1e9:.2f} G")  # 输出: ~4.11 GFLOPs
print(f"Params: {params / 1e6:.2f} M")  # 输出: ~25.56 Million

get_model_complexity_info方法

from ptflops import get_model_complexity_infomacs, params = get_model_complexity_info(self.net,(3, 1280, 1280),  # (channels, height, width)as_strings=True,print_per_layer_stat=True,  # 打印每层计算量verbose=True,
)
print(f"FLOPs: {macs}")
print(f"Params: {params}")# Warning: module IntermediateLayerGetter,FPN,SSH,ClassHead,BboxHead,LandmarkHead,RetinaFace,DataParallel is treated as a zero-op.
# DataParallel(
#  426.61 k, 100.000% Params, 4.07 GMac, 99.943% MACs, 
#  (module): RetinaFace(
#    426.61 k, 100.000% Params, 4.07 GMac, 99.943% MACs, 
#    (body): IntermediateLayerGetter(
#      213.07 k, 49.946% Params, 1.45 GMac, 35.733% MACs, 
#      (stage1): Sequential(
#        10.13 k, 2.374% Params, 642.25 MMac, 15.774% MACs, 
#        (0): Sequential(
#          232, 0.054% Params, 98.3 MMac, 2.414% MACs, 
#          (0): Conv2d(216, 0.051% Params, 88.47 MMac, 2.173% MACs, 3, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
#          (1): BatchNorm2d(16, 0.004% Params, 6.55 MMac, 0.161% MACs, 8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#          (2): LeakyReLU(0, 0.000% Params, 3.28 MMac, 0.080% MACs, negative_slope=0.1, inplace=True)
#        )
#        (1): Sequential()
#  ......
# )
# FLOPs: 4.07 GMac
# Params: 426.61 k

FlopCountAnalysis方法

from fvcore.nn import FlopCountAnalysisflops = FlopCountAnalysis(self.net, image)
flops = flops.unsupported_ops_warnings(False)  # 忽略不支持的操作警告# 计算总 FLOPs
print(flops.by_module())  # 打印每个模块的 FLOPs
total_flops = flops.total()
print(f"Total FLOPs: {total_flops / 1e9:.2f} G")# 打印每一层的 FLOPs,返回字典 {模块名: FLOPs}
print(flops.by_module())# 打印按模块分组的 FLOPs
print(flops.by_module_and_operator())  # 更详细的统计
http://www.xdnf.cn/news/9652.html

相关文章:

  • 日常踩坑-pom文件里jdbc配置问题
  • buunctf Crypto-[WUSTCTF2020]情书1
  • 模具制造业数字化转型:精密模塑,以数字之力铸就制造基石
  • 5月28日星期三今日早报简报微语报早读
  • AI任务相关解决方案1-基于NLP的3种模型实现实体识别,以及对比分析(包括基于规则的方法、CRF模型和BERT微调模型)
  • SQL进阶之旅 Day 6:数据更新最佳实践
  • STP协议:如何消除网络环路风暴
  • 【分治】翻转对
  • jsrpc进阶模式 秒杀js前端逆向问题 burp联动进行爆破
  • 【JavaEE】Spring事务
  • c++设计模式-介绍
  • 摩尔条纹 原理以及matlab 实现
  • 数据结构 - 树的遍历
  • 【JavaEE】-- 网络原理
  • NetLink
  • SNTP在电力系统通信中的应用
  • C# NX二次开发-查找连续倒圆角面
  • GB/T 36140-2018 装配式玻纤增强无机材料复合保温墙体检测
  • 【第2章 绘制】2.7 路径、描边与填充
  • 【C++进阶篇】哈希表的模拟实现(赋源码)
  • WSL中ubuntu通过Windows带代理访问github
  • 【razor】采集的同时支持预览和传输的讨论和改造方案探讨
  • DAY38
  • 整合Jdk17+Spring Boot3.2+Elasticsearch9.0+mybatis3.5.12的简单用法
  • 电化学震荡- N 型负微分电阻
  • Android LiveData 详解
  • QT使用cmake添加资源文件闪退,创建了qrc文件不能添加的问题解决
  • 深圳SMT贴片打样全流程优化方案
  • 在监视器(Monitor)内部,是如何做线程同步的?
  • 半桥栅极驱动芯片D2104M使用手册