当前位置: 首页 > ai >正文

torchsummary库中的summary()函数

torchsummary库中的summary()函数是PyTorch中用于可视化模型结构的核心工具,其作用类似于TensorFlow的model.summary()。它通过生成详细的表格输出,帮助开发者直观理解模型层次、参数分布和计算资源需求。以下是其核心功能详解:


📊 ​​1. 核心功能​

  • ​模型结构可视化​
    输出包含每一层的类型(如Conv2dLinear)、名称、输出张量形状(Output Shape)和参数量(Param #)。

    例如:

    其中Output Shape-1表示动态的批量大小(batch size),后续维度为特征图或向量的形状。

  • ​参数量统计​
    汇总总参数量(Total params)、可训练参数(Trainable params)及不可训练参数(如冻结层)。

  • ​内存占用分析​
    计算模型的内存开销,包括:

    • 输入数据占用(Input size (MB)
    • 前向/反向传播中间变量占用(Forward/backward pass size (MB)
    • 参数存储占用(Params size (MB)
    • 预估总内存(Estimated Total Size (MB))。

⚙️ ​​2. 使用方法​

​安装​
pip install torchsummary -i https://mirrors.aliyun.com/pypi/simple/
​代码示例​
from torchsummary import summary
import torch.nn as nn# 定义模型
class SimpleModel(nn.Module):def __init__(self):super().__init__()self.conv = nn.Conv2d(3, 16, kernel_size=3)self.fc = nn.Linear(16 * 30 * 30, 10)  # 假设输入32x32,卷积后尺寸为30x30def forward(self, x):x = self.conv(x)x = x.view(x.size(0), -1)x = self.fc(x)return x# 实例化并调用summary
model = SimpleModel()
summary(model, input_size=(3, 32, 32), device="cpu")  # 指定输入尺寸和设备
​参数说明​
  • model:继承nn.Module的PyTorch模型。
  • input_size:输入张量形状(C, H, W),​​不含batch size​​(自动添加-1占位)。
  • device:可选"cuda""cpu",​​必须与模型所在设备一致​​,否则报错(如RuntimeError: Input type and weight type should be the same)。
  • batch_size:可选,控制输出形状中的批量占位符(默认为-1)。

🚨 ​​3. 常见问题与注意事项​

  1. ​设备匹配​
    若模型在CPU上,需显式设置device="cpu",否则默认使用GPU(device="cuda")会引发类型错误。

  2. ​输入尺寸要求​
    input_size需与模型实际输入一致。例如:

    • RGB图像:(3, H, W)
    • 灰度图:(1, H, W)
    • 全连接网络:(input_dim,)(如(784,)对应MNIST展平后向量)。
  3. ​动态结构支持​
    若模型前向传播包含条件分支或动态操作(如x.view()),需确保输入尺寸与view/flatten操作兼容,否则输出形状可能计算错误。

  4. ​输出解读​

    • Output Shape中的[-1, C, H, W]:卷积/池化层输出。
    • [-1, D]:全连接层输出(D为特征维度)。

💡 ​​4. 典型应用场景​

  • ​模型调试​​:快速验证各层输出尺寸是否匹配,避免维度不匹配错误。
  • ​复杂度评估​​:通过参数量和内存占用优化模型结构(如减少冗余层)。
  • ​论文/报告展示​​:生成简洁的架构摘要表格。

🌰 ​​输出示例解析​


  • ​参数量计算​​:
    卷积层:(3×3×3+1)*16 = 448(权重+偏置)
    全连接层:(16×30×30+1)*10 = 144,010。
  • ​内存估算​​:帮助预判模型在边缘设备的部署可行性。

通过summary(),开发者无需逐层打印调试即可全局掌握模型结构,显著提升开发效率。尤其适合需要快速迭代模型或资源受限的场景。

http://www.xdnf.cn/news/14191.html

相关文章:

  • Kerberos快速入门之基本概念与认证基本原理
  • OpenLayers 创建坐标系统
  • Flower框架中noise_multiplier与clipped_count_stddev的关系
  • [智能客服project] AI代理系统 | 意图路由器
  • pikachu靶场通关笔记30 文件包含01之本地文件包含
  • Typecho安装后后台 404 报错解决
  • CMake实践: 以开源库QSimpleUpdater为例,详细讲解编译、查找依赖等全过程
  • Reqable・API 抓包调试 + API 测试一站式工具
  • 17_Flask部署到网络服务器
  • 【软测】接口测试 - 用postman测试软件登录模块
  • 微机原理与接口技术,期末冲刺复习资料(汇总版)
  • Linux进程间通信(IPC)详解:从入门到理解
  • H5 技术与定制开发工具融合下的数字化营销新探索
  • 高效录屏工具推荐:从系统自带到专业进阶
  • 函数调用过程中的栈帧变化
  • 普通Dom转换为可拖拽、最大化、最小化窗口js插件
  • 【在线五子棋对战】六、项目结构设计 工具模块实现
  • 【unitrix】 1.6 数值类型基本结构体(types.rs)
  • 商用油烟净化器日常维护的标准化流程
  • Arduino入门教程:4-1、代码基础-进阶
  • 静态变量详解(static variable)
  • 微博项目(总体搭建)
  • Javascript什么是原型和原型链,八股文
  • java面试总结-20250609
  • 数据结构 学习 图 2025年6月14日 12点57分
  • spring如何处理bean的循环依赖
  • NuttX 调度器源码学习
  • 吃透 Golang 基础:方法
  • 湖南源点(市场研究)咨询 DNF下沉市场用户研究项目之调研后感
  • 03、继承与多态