当前位置: 首页 > java >正文

pytorch学习笔记-模型的保存与加载(自定义模型、网络模型)

博主最近勤奋更新的原因是一来之前攒了一些囤的,二来是终于要学完了一鼓作气啊啊啊啊

这一节写一下模型保存&加载,推荐方式2,方式1了解一下就ok

现有的网络模型的保存与加载

先要引入现有的网络模型

import torch
import torchvision
from torch import nnvgg16 = torchvision.models.vgg16(weights=None)

保存方式1&加载方式1

保存方式:
#保存方式1
torch.save(vgg16,"vgg16_method1.pth")
加载方式:

注意这里不写weights_only会提示:
(1) In PyTorch 2.6, we changed the default value of the weights_only argument in torch.load from False to True. Re-running torch.load with weights_only set to False will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.懒得翻译了大概看一眼吧
(2)xxxxx…就是建议你用方式2写

#加载方式1
model = torch.load("vgg16_method1.pth",weights_only=False)
# print(model)

保存方式2&加载方式2

保存方式:
#方式2,以字典形式存储
torch.save(vgg16.state_dict(),"vgg16_method2.pth")
加载方式:

注意点就是要先定义一个模型,然后再把参数导入到模型中

#方式2存储加载
#要先定义一个模型,然后再把参数导入到模型中
vgg16 = torchvision.models.vgg16(weights=None)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))

自定义网络模型的保存与加载

假设你在model_save.py文件中定义了这样一个model:

class MyModule(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3,64,kernel_size=3)def forward(self,x):x = self.conv1(x)return xmy_module = MyModule()

保存方式1&加载方式1

保存方式:
#保存方式1
torch.save(my_module,"my_module_method1.pth")
加载方式:

注意一下就是如果你在model_load.py中如果没有对应的网络结构,会加载失败,因此需要引入自定义的模型,不用实例化,可选方式有1.引入包(推荐)2.把网络结构复制过来

#加载自定义的模型
#要求引入自定义的模型,不用实例化
#可选方式可以引入包(推荐)或者把网络结构复制过来
model = torch.load("my_module_method1.pth",weights_only=False)
# print(model)

保存方式2&加载方式2

保存方式:
#自定义模型存储2
torch.save(my_module.state_dict(),"my_module_method2.pth")
加载方式:

和加载现有的网络模型差不多,都是要先定义一个模型,然后再把参数导入到模型中

model2 = MyModule()
model2.load_state_dict(torch.load("my_module_method2.pth"))
print(model2)
http://www.xdnf.cn/news/17937.html

相关文章:

  • 【fwk基础】repo sync报错后如何快速修改更新
  • 图片滤镜处理(filters)
  • 戴永红×数图:重构零售空间价值,让陈列创造效益!
  • 机器翻译:模型微调(Fine-tuning)与调优详解
  • Comfyui进入python虚拟环境
  • 大数据系列之:设置CMS垃圾收集器
  • 如何在 Ubuntu 24.04 Noble LTS 上安装 Apache 服务器
  • 龙虎榜——20250815
  • 【网络】IP总结复盘
  • IDEA 清除 ctrl+shift+r 全局搜索记录
  • SAP ALV导出excel 报 XML 错误的 /xl/sharedStrings.xml
  • STM32在使用DMA发送和接收时的模式区别
  • 数据处理分析环境搭建+Numpy使用教程
  • 主流开源实时互动数字人大模型
  • 易道博识康铁钢:大小模型深度融合是现阶段OCR的最佳解决方案
  • imx6ull-驱动开发篇25——Linux 中断上半部/下半部
  • 超级云 APP 模式:重构移动互联网生态的新引擎
  • Radar Forward-Looking Imaging Based on Chirp Beam Scanning论文阅读
  • 列式存储与行式存储:核心区别、优缺点及代表数据库
  • 代码随想录Day51:图论(岛屿数量 深搜广搜、岛屿的最大面积)
  • 第七十二章: AI训练的“新手村”指南:小规模链路构建与调参技巧——从零开始,驯服你的模型!
  • Java面试实战系列【并发篇】- Semaphore深度解析与实战
  • gnu arm toolchain中的arm-none-eabi-gdb.exe的使用方法?
  • 【C#补全计划】委托
  • uniapp 开发微信小程序,获取经纬度并且转化详细地址(单独封装版本)
  • 零基础-动手学深度学习-10.4. Bahdanau 注意力
  • 电脑上练打字用什么软件最好:10款打字软件评测
  • 【学习笔记】Java并发编程的艺术——第10章 Executor框架
  • VUE3 学习笔记2 computed、watch、生命周期、hooks、其他组合式API
  • RecyclerView 性能优化:从原理到实践的深度优化方案