当前位置: 首页 > backend >正文

深度学习——PyTorch保存模型与调用模型


神经网络保存模型与调用模型

在深度学习的实际应用中,模型训练通常需要大量的数据和计算资源。如果每次使用时都从头开始训练,不仅效率低下,还会浪费大量时间。因此,将训练好的模型保存下来,并在需要时直接调用,是非常重要的步骤。本文将详细介绍在 PyTorch 框架下如何保存模型与调用模型。


一、为什么要保存模型?

  1. 节省时间和计算成本:训练神经网络可能需要数小时甚至数天,保存模型可以避免重复训练。

  2. 迁移学习:保存预训练模型,方便后续在其他任务上进行微调。

  3. 模型部署:在推理阶段,通常只需要加载已经训练好的模型权重进行预测。

  4. 实验复现:保存模型可以帮助研究人员复现实验结果,保证实验的可重复性。


二、模型保存的方式

在 PyTorch 中,保存模型主要有两种方式:

1. 保存整个模型

torch.save(model, "model.pth")
  • model:整个模型对象。

  • "model.pth":保存的文件名(扩展名常用 .pth.pt)。

优点:简单,保存后直接可以加载使用。
缺点:依赖代码结构,如果模型类定义发生变化,可能无法加载。

2. 只保存模型参数(推荐)

torch.save(model.state_dict(), "model_params.pth")
  • model.state_dict():返回一个包含所有模型参数的字典。

优点:更灵活、通用,加载时只需保证模型结构一致即可。
缺点:加载时需要先重新定义模型结构,再加载参数。


三、模型调用(加载)

1. 加载整个模型

model = torch.load("model.pth")
model.eval()  # 切换为推理模式

注意:调用时必须确保有相同的代码环境和依赖。

2. 加载模型参数

# 先定义模型结构
model = MyModelClass()
# 再加载参数
model.load_state_dict(torch.load("model_params.pth"))
model.eval()

推荐做法,因为这样更灵活,尤其在迁移学习或分布式训练时。


四、完整示例

下面给出一个小例子,演示如何训练、保存和调用模型。

import torch
import torch.nn as nn
import torch.optim as optim# ====== 定义模型 ======
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(10, 20)self.fc2 = nn.Linear(20, 1)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return xmodel = SimpleNet()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# ====== 模拟训练 ======
x = torch.randn(100, 10)
y = torch.randn(100, 1)
for epoch in range(5):outputs = model(x)loss = criterion(outputs, y)optimizer.zero_grad()loss.backward()optimizer.step()print(f"Epoch [{epoch+1}/5], Loss: {loss.item():.4f}")# ====== 保存模型 ======
# 方法1:保存整个模型
torch.save(model, "whole_model.pth")# 方法2:保存参数(推荐)
torch.save(model.state_dict(), "model_params.pth")

解释

  • SimpleNet 继承自 nn.Module,是一个最简单的全连接网络。

  • nn.Linear(in_features, out_features):定义全连接层(权重矩阵 + 偏置)。

  • ReLU 激活函数:非线性变换,提升模型表达能力。

  • forward(x):定义前向传播过程。

  • 实例化模型,得到一个 SimpleNet对象

  • criterion = nn.MSELoss()

    • 损失函数:均方误差(Mean Squared Error),常用于回归任务。

    • 计算预测值与真实值之间的平方差。

  • optimizer = optim.SGD(model.parameters(), lr=0.01)

    • 优化器:随机梯度下降(SGD)。

    • model.parameters() 表示需要优化的参数(网络的权重和偏置)。

    • lr=0.01:学习率,控制每次更新参数的幅度。

  • 使用 torch.randn 生成随机数据,模拟训练集。

    • x 的维度 (100, 10):100 条输入,每条是 10 个特征。

    • y 的维度 (100, 1):100 条输出,每条是 1 个目标值。

  • 前向传播model(x) 得到预测结果。

  • 计算损失criterion(outputs, y)

  • 梯度清零:PyTorch 的梯度是累积的,需要 optimizer.zero_grad()

  • 反向传播loss.backward() 自动计算每个参数的梯度。

  • 更新参数optimizer.step() 根据梯度更新权重。

  • 打印结果:显示当前 epoch 的损失值。

调用模型

# 方法1:加载整个模型
loaded_model = torch.load("whole_model.pth")
loaded_model.eval()# 方法2:加载模型参数
new_model = SimpleNet()
new_model.load_state_dict(torch.load("model_params.pth"))
new_model.eval()# 测试推理
test_input = torch.randn(1, 10)
print(new_model(test_input))

五、保存与调用中的注意事项

  1. 训练模式 vs 推理模式

    • 在推理前调用 model.eval(),关闭 DropoutBatchNorm 的训练行为。

    • 在继续训练时调用 model.train()

  2. 保存优化器状态
    如果需要在中断后继续训练,不仅要保存模型参数,还需要保存优化器状态:

    torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss
    }, "checkpoint.pth")
    

    调用时:

    checkpoint = torch.load("checkpoint.pth")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    
  3. 文件扩展名
    .pth.pt 都是常见的约定,不影响实际功能。

  4. 跨设备加载
    如果模型是在 GPU 上保存的,但在 CPU 上加载,需要指定:

    torch.load("model_params.pth", map_location=torch.device('cpu'))
    

六、总结

  • 保存方式:保存整个模型(简单) vs 保存参数(推荐)。

  • 调用方式:需注意模型结构定义和推理模式设置。

  • 高级用法:保存和加载优化器状态,以便断点续训。

通过合理地保存和调用模型,可以显著提高实验效率,方便模型复现与部署。

http://www.xdnf.cn/news/20245.html

相关文章:

  • JUC之并发编程
  • MyBatis入门到精通:CRUD实战指南
  • 使用UniApp实现下拉框和表格组件页面
  • Android Kotlin 动态注册 Broadcast 的完整封装方案
  • uv教程 虚拟环境
  • kotlin - 2个Fragment实现左右显示,左边列表,右边详情,平板横、竖屏切换
  • 【LeetCode 每日一题】2348. 全 0 子数组的数目
  • 开源OpenHarmony润开鸿HH-SCDAYU800A开发板开箱体验
  • AI热点周报(8.31~9.6): Qwen3‑Max‑Preview上线、GLM-4.5提供一键迁移、Gemini for Home,AI风向何在?
  • C++进阶——继承(2)
  • 基于STM32的交通灯设计—紧急模式、可调时间
  • 如何理解`(line_status = parse_line()) == LINE_OK`?
  • @Autowired注解(二)
  • 【CAN通信】AUTOSAR架构下TC3xx芯片是如何将一帧CAN报文接收上来的
  • Xsens解码人形机器人训练的语言
  • 如何通过AI进行数据资产梳理
  • 43这周打卡——生成手势图像 (可控制生成)
  • 球坐标系下调和函数的构造:多项式边界条件的求解方法
  • linux Nginx服务配置介绍,和配置流程
  • 快手Keye-VL 1.5开源128K上下文+0.1秒级视频定位+跨模态推理,引领视频理解新标杆
  • 错误是ModuleNotFoundError: No module named ‘pip‘解决“找不到 pip”
  • vsan default storage policy 具体是什么策略?
  • HTB GoodGames
  • centos下gdb调试python的core文件
  • 串口通信的学习
  • 日内5%,总回撤10%:EagleTrader风控规则里,隐藏着什么核心考点?
  • 使用API接口获取淘宝商品详情数据需要注意哪些风险?
  • MySQL数据库精研之旅第十六期:深度拆解事务核心(上)
  • python + Flask模块学习 1 基础用法
  • IC ATE集成电路测试学习——Stuck-at fault And Chain(一)