深度学习——PyTorch保存模型与调用模型
神经网络保存模型与调用模型
在深度学习的实际应用中,模型训练通常需要大量的数据和计算资源。如果每次使用时都从头开始训练,不仅效率低下,还会浪费大量时间。因此,将训练好的模型保存下来,并在需要时直接调用,是非常重要的步骤。本文将详细介绍在 PyTorch 框架下如何保存模型与调用模型。
一、为什么要保存模型?
节省时间和计算成本:训练神经网络可能需要数小时甚至数天,保存模型可以避免重复训练。
迁移学习:保存预训练模型,方便后续在其他任务上进行微调。
模型部署:在推理阶段,通常只需要加载已经训练好的模型权重进行预测。
实验复现:保存模型可以帮助研究人员复现实验结果,保证实验的可重复性。
二、模型保存的方式
在 PyTorch 中,保存模型主要有两种方式:
1. 保存整个模型
torch.save(model, "model.pth")
model
:整个模型对象。"model.pth"
:保存的文件名(扩展名常用.pth
或.pt
)。
优点:简单,保存后直接可以加载使用。
缺点:依赖代码结构,如果模型类定义发生变化,可能无法加载。
2. 只保存模型参数(推荐)
torch.save(model.state_dict(), "model_params.pth")
model.state_dict()
:返回一个包含所有模型参数的字典。
优点:更灵活、通用,加载时只需保证模型结构一致即可。
缺点:加载时需要先重新定义模型结构,再加载参数。
三、模型调用(加载)
1. 加载整个模型
model = torch.load("model.pth")
model.eval() # 切换为推理模式
注意:调用时必须确保有相同的代码环境和依赖。
2. 加载模型参数
# 先定义模型结构
model = MyModelClass()
# 再加载参数
model.load_state_dict(torch.load("model_params.pth"))
model.eval()
推荐做法,因为这样更灵活,尤其在迁移学习或分布式训练时。
四、完整示例
下面给出一个小例子,演示如何训练、保存和调用模型。
import torch
import torch.nn as nn
import torch.optim as optim# ====== 定义模型 ======
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(10, 20)self.fc2 = nn.Linear(20, 1)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return xmodel = SimpleNet()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# ====== 模拟训练 ======
x = torch.randn(100, 10)
y = torch.randn(100, 1)
for epoch in range(5):outputs = model(x)loss = criterion(outputs, y)optimizer.zero_grad()loss.backward()optimizer.step()print(f"Epoch [{epoch+1}/5], Loss: {loss.item():.4f}")# ====== 保存模型 ======
# 方法1:保存整个模型
torch.save(model, "whole_model.pth")# 方法2:保存参数(推荐)
torch.save(model.state_dict(), "model_params.pth")
解释
SimpleNet 继承自
nn.Module
,是一个最简单的全连接网络。nn.Linear(in_features, out_features)
:定义全连接层(权重矩阵 + 偏置)。ReLU 激活函数:非线性变换,提升模型表达能力。
forward(x):定义前向传播过程。
实例化模型,得到一个 SimpleNet对象
criterion = nn.MSELoss()
损失函数:均方误差(Mean Squared Error),常用于回归任务。
计算预测值与真实值之间的平方差。
optimizer = optim.SGD(model.parameters(), lr=0.01)
优化器:随机梯度下降(SGD)。
model.parameters()
表示需要优化的参数(网络的权重和偏置)。lr=0.01
:学习率,控制每次更新参数的幅度。
使用
torch.randn
生成随机数据,模拟训练集。x 的维度
(100, 10)
:100 条输入,每条是 10 个特征。y 的维度
(100, 1)
:100 条输出,每条是 1 个目标值。
前向传播:
model(x)
得到预测结果。计算损失:
criterion(outputs, y)
。梯度清零:PyTorch 的梯度是累积的,需要
optimizer.zero_grad()
。反向传播:
loss.backward()
自动计算每个参数的梯度。更新参数:
optimizer.step()
根据梯度更新权重。打印结果:显示当前
epoch
的损失值。
调用模型
# 方法1:加载整个模型
loaded_model = torch.load("whole_model.pth")
loaded_model.eval()# 方法2:加载模型参数
new_model = SimpleNet()
new_model.load_state_dict(torch.load("model_params.pth"))
new_model.eval()# 测试推理
test_input = torch.randn(1, 10)
print(new_model(test_input))
五、保存与调用中的注意事项
训练模式 vs 推理模式
在推理前调用
model.eval()
,关闭Dropout
、BatchNorm
的训练行为。在继续训练时调用
model.train()
。
保存优化器状态
如果需要在中断后继续训练,不仅要保存模型参数,还需要保存优化器状态:torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss }, "checkpoint.pth")
调用时:
checkpoint = torch.load("checkpoint.pth") model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) start_epoch = checkpoint['epoch'] loss = checkpoint['loss']
文件扩展名
.pth
、.pt
都是常见的约定,不影响实际功能。跨设备加载
如果模型是在 GPU 上保存的,但在 CPU 上加载,需要指定:torch.load("model_params.pth", map_location=torch.device('cpu'))
六、总结
保存方式:保存整个模型(简单) vs 保存参数(推荐)。
调用方式:需注意模型结构定义和推理模式设置。
高级用法:保存和加载优化器状态,以便断点续训。
通过合理地保存和调用模型,可以显著提高实验效率,方便模型复现与部署。