当前位置: 首页 > java >正文

PyTorch模型保存方式

PyTorch提供两种主流模型保存方式和一种训练断点保存与恢复的方法。

1. 仅保存模型参数(推荐)

# 保存
torch.save(model.state_dict(), "model_params.pth")  # 加载
new_model = TheModelClass()  
new_model.load_state_dict(torch.load("model_params.pth"))
new_model.eval()

核心优势:

  • 文件体积小(仅参数数据)

  • 避免PyTorch版本兼容问题

  • 支持跨模型结构迁移(需设置strict=False

2. 保存完整模型对象
# 保存
torch.save(model, "full_model.pth")  # 加载   loaded_model = torch.load("full_model.pth")
loaded_model.eval()

适用场景:

  • 快速原型验证

  • 模型结构包含动态逻辑(如自定义前向传播)

3. 训练断点保存与恢复
# 保存检查点
checkpoint = {'epoch': current_epoch,'model_state': model.state_dict(),'optimizer_state': optimizer.state_dict(),'loss': loss_value
}
torch.save(checkpoint, "checkpoint.tar")# 恢复训练
model = TheModelClass()
optimizer = torch.optim.Adam(model.parameters())
checkpoint = torch.load("checkpoint.tar")
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
model.train()  # 保持训练模式

关键细节:

  • 推荐使用.tar后缀区分普通参数文件

  • 自动恢复学习率调度器等训练状态

http://www.xdnf.cn/news/7318.html

相关文章:

  • C++ —— Lambda 表达式
  • 虚拟地址空间
  • 第四章、SKRL(1): Examples
  • Python实例题:Python 实现简易 Shell
  • Python的传参过程的小细节
  • 什么是5G前传、中传、回传?
  • 数据分析—Excel数据清洗函数
  • Compose Kotlin Multiplatform跨平台基础运行
  • CM0启动CM7_0、CM7_1注意事项
  • PCB设计教程【入门篇】——电路分析基础-基本元件(电阻电容电感)
  • Docker 入门指南:从安装配置到核心概念解析
  • [ 计算机网络 ] | 宏观谈谈计算机网络
  • 十三、Hive 行列转换
  • 计算机视觉与深度学习 | Python实现ARIMA-WOA-CNN-LSTM时间序列预测(完整源码和数据
  • netcore项目使用winforms与blazor结合来开发如何按F12,可以调出chrome devtool工具辅助开发
  • 通过低功耗蓝牙通信实例讲透 MCU 各个定时器
  • AT 指令详解:基于 MCU 的通信控制实战指南AT 指令详解
  • ESP32开发-两个WIFI设备的通讯搭建
  • AI大模型从0到1记录学习numpy pandas day25
  • 无人设备遥控器之数据压缩与编码技术篇
  • PLC组网的方法、要点及实施全解析
  • android13以太网静态ip不断断开连上问题
  • C++(24):容器类<list>
  • Unreal 从入门到精通之SceneCaptureComponent2D实现UI层3D物体360°预览
  • MAC常用操作整理
  • MAC电脑中右键后复制和拷贝的区别
  • C++:与7无关的数
  • 基于 Vue 和 Node.js 实现图片上传功能:从前端到后端的完整实践
  • 汽车零部件的EMI抗扰性测试
  • Java中的流详解