当前位置: 首页 > java >正文

Pytorch 的模型保存

预加载模型及存储

import torchvision.models as models
models.resnet18(pretrained=True)

加载预训练好的模型,如未在路径中找到对应模型,会自动从网上下载。

默认的模型保存路径:C:\Users\Administrator\.torch;

如需修改模型保存路径,可在程序脚本中制定,如:os.environ['TORCH_HOME'] = 'D:\\PyTorch'

也可以在环境变量中为‘TORCH_HOME’配置路径;模型下载好后,程序会在指定路径下的 models 文件夹中加载对应的模型,下载的时候如没有这个文件夹会自动创建。

Pytorch 模型保存

保存模型权重

模型权重文件较小,与模型的定义解耦,便于在不同环境中加载。

torch.save(model.state_dict(), 'model_weights.pth')

加载模型权重:加载时要先定义相同的模型结构,再加载权重:
例如,模型结构如下
model = resnet18(pretrained=False)  # 使用与保存时相同的模型架构
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes) 

# 加载模型权重
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()  # 切换到评估模式

保存整个模型

包括模型的结构和权重,加载时不需要重新定义模型结构,但文件较大。
保存整个模型:torch.save(model, 'model_full.pth')

加载整个模型:model = torch.load('model_full.pth')

保存和加载模型时的设备

在保存和加载模型时,要保持CPU 或 GPU的一致性。如果保存时使用的是 GPU,加载时也需要使用 GPU,反之亦然。如要在不同设备之间加载模型,可以使用 `map_location` 参数。

在 CPU 上加载 GPU 训练的模型

加载模型时指定设备
model = torch.load('model_full.pth', map_location=torch.device('cpu'))

在 GPU 上加载 CPU 训练的模型

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.load('model_full.pth', map_location=device)
model.to(device).eval()

保存训练状态

在保存模型时,还可以保存优化器的状态,以便在训练中断后恢复训练。
  torch.save({
      'epoch': epoch,
      'model_state_dict': model.state_dict(),
      'optimizer_state_dict': optimizer.state_dict(),
      'loss': loss,
  }, 'checkpoint.pth')

加载:
  checkpoint = torch.load('checkpoint.pth')
  model.load_state_dict(checkpoint['model_state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  epoch = checkpoint['epoch']
  loss = checkpoint['loss']

http://www.xdnf.cn/news/4179.html

相关文章:

  • 数据结构(一)——线性表的顺序表示和实现
  • k8s术语之service
  • k8s pod request/limit 值不带单位会发生什么?
  • 浅谈 - GPTQ为啥按列量化
  • NGINX `ngx_http_browser_module` 深度解析与实战
  • 螺杆支撑座:数控机床高效稳定运行的关键支撑
  • MYSQL的DDL语言和单表查询
  • 完全免费的PDF电子发票批量辅助打印工具
  • vue3+ts继续学习
  • js var a=如果ForRemove=true,是“normal“,否则为“bold“
  • 2025-05-06 事业-独立开发项目-记录
  • 软件代码签名证书SSL如何选择?
  • C++复习2
  • Spring Boot之MCP Client开发全介绍
  • 二叉树—中序遍历—非递归
  • 两数之和(暴力+哈希查找)
  • Linux[Makefile]
  • ffmpeg录音测试
  • 爬虫程序中如何添加异常处理?
  • Vi/Vim 编辑器详细指南
  • Facebook如何运用AI实现元宇宙的无限可能?
  • DC-DC降压型开关电源(Buck Converter)设计中,开关频率(f sw​ )、滤波电感(L)和滤波电容(C out​ )的关系和取舍
  • uniapp 全局混入:监听路由变化,路由变化即执行
  • 嵌入式openharmony标准鸿蒙系统驱动开发基本原理与流程
  • openssl 生成自签名证书实现接口支持https
  • 【coze】手册小助手(提示词、知识库、交互、发布)
  • C++中指针使用详解(4)指针的高级应用汇总
  • 人工智能对人类的影响
  • 【Hive入门】Hive安全管理与权限控制:审计日志全解析,构建完善的操作追踪体系
  • kubeadm部署k8s