当前位置: 首页 > ops >正文

深度学习篇---模型参数保存

在深度学习模型训练和部署过程中,模型保存是一个关键环节。不同框架在模型保存的实现上既有相似之处,也有各自的特点。下面详细介绍 PyTorch、TensorFlow 和 PaddlePaddle 中模型保存的代码及保存内容:

1. PyTorch

PyTorch 提供了灵活的模型保存方式,主要通过torch.save()函数实现,可保存模型结构、参数或训练状态。

(1)保存模型参数(推荐)

仅保存模型的参数(权重和偏置),不包含模型结构,文件体积较小。

import torch
import torch.nn as nn# 定义示例模型
class SimpleModel(nn.Module):def __init__(self):super().__init__()self.fc = nn.Linear(10, 2)def forward(self, x):return self.fc(x)model = SimpleModel()# 保存模型参数(状态字典,state_dict)
torch.save(model.state_dict(), "model_params.pth")
  • 保存内容:模型的state_dict,是一个字典,层名称对应参数的张量
  • 用途:适用于训练中断后恢复训练,或在已知模型结构的情况下加载参数。
(2)保存完整模型

保存整个模型(包括结构和参数),但可能存在兼容性问题(如不同 PyTorch 版本或 Python 环境)。

# 保存完整模型
torch.save(model, "full_model.pth")
  • 保存内容:模型的类结构、参数及其他属性(如训练配置)。
  • 注意:不推荐用于跨环境部署,可能因类定义变化导致加载失败。
(3)保存训练过程状态(断点续训)

保存模型参数、优化器状态、epoch 等信息,用于中断后继续训练。

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epoch = 10
loss = 0.123# 保存训练状态
checkpoint = {"model_state_dict": model.state_dict(),"optimizer_state_dict": optimizer.state_dict(),"epoch": epoch,"loss": loss
}
torch.save(checkpoint, "checkpoint.pth")
  • 保存内容:模型参数、优化器参数(如动量、学习率)、当前训练轮次、损失值等。

2. TensorFlow(Keras)

TensorFlow(尤其是 Keras 接口)提供了多种模型保存方式,支持 SavedModel 格式(推荐)和 HDF5 格式。

(1)保存完整模型(SavedModel 格式,推荐)

SavedModel 是 TensorFlow 的标准格式,包含模型结构、参数、计算图等,兼容性强。

  • 保存内容
    • 模型结构(网络层、输入输出形状);
    • 所有参数(权重和偏置);
    • 训练配置(优化器、损失函数、 metrics);
    • 计算图(用于部署到 TensorFlow Serving、移动端等)。
  • 用途:模型部署、跨平台使用(如 TensorFlow Lite、TensorRT)。
(2)保存为 HDF5 格式

保存模型结构和参数到单一文件,适用于简单场景。

# 保存为HDF5格式
model.save("model.h5")
  • 保存内容:模型结构(JSON 格式)和参数(二进制),但不包含计算图细节。
  • 注意:对复杂模型(如自定义层、控制流)的兼容性较差。
(3)保存权重(仅参数)

仅保存模型参数,需已知模型结构才能加载。

# 保存权重
model.save_weights("model_weights.h5")
  • 保存内容:各层的权重张量,不包含模型结构。
(4)训练过程保存(Checkpoint)

通过ModelCheckpoint回调保存训练过程中的模型状态。

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath="training_checkpoint",save_weights_only=False,  # 是否仅保存权重save_best_only=True,      # 仅保存性能最好的模型monitor="val_loss"        # 监控指标
)# 训练时使用回调
model.fit(x_train, y_train, epochs=10, callbacks=[checkpoint_callback])
  • 保存内容:根据配置,可保存完整模型或仅权重,支持按指标(如验证集损失)保存最优模型。

3. PaddlePaddle

PaddlePaddle 的模型保存逻辑与 PyTorch 类似,主要通过paddle.save()Model.save()实现。

(1)保存模型参数(推荐)

仅保存模型参数,需结合模型结构加载。

import paddle
from paddle.nn import Linear# 定义示例模型
class SimpleModel(paddle.nn.Layer):def __init__(self):super().__init__()self.fc = Linear(in_features=10, out_features=2)def forward(self, x):return self.fc(x)model = SimpleModel()# 保存模型参数
paddle.save(model.state_dict(), "model_params.pdparams")
  • 保存内容:模型的state_dict,键为层名称,值为参数张量。
(2)保存完整模型

保存模型结构和参数,方便直接加载使用。

# 保存完整模型
paddle.Model(model).save("full_model")
  • 保存内容:模型结构(__model__文件)和参数(*.pdparams),支持跨环境加载。
(3)保存训练过程状态(断点续训)

保存模型参数、优化器状态、训练轮次等。

optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.001)
epoch = 10
loss = 0.123# 保存训练状态
checkpoint = {"model_state_dict": model.state_dict(),"optimizer_state_dict": optimizer.state_dict(),"epoch": epoch,"loss": loss
}
paddle.save(checkpoint, "checkpoint.pdparams")
  • 保存内容:模型参数、优化器参数(如学习率、动量)、训练进度等。

总结

框架保存类型核心函数 / 方法主要保存内容
PyTorch仅参数torch.save(model.state_dict(), ...)模型参数(state_dict)
完整模型torch.save(model, ...)模型结构 + 参数
训练状态(断点续训)torch.save(checkpoint_dict, ...)模型参数 + 优化器状态 + 训练进度
TensorFlow完整模型(推荐)model.save("saved_model")结构 + 参数 + 计算图 + 训练配置
HDF5 格式model.save("model.h5")结构 + 参数(兼容性有限)
仅参数model.save_weights(...)各层权重
训练过程检查点ModelCheckpoint回调按配置保存模型或权重(支持最优模型选择)
PaddlePaddle仅参数paddle.save(model.state_dict(), ...)模型参数(state_dict)
完整模型paddle.Model(model).save(...)结构 + 参数
训练状态(断点续训)paddle.save(checkpoint_dict, ...)模型参数 + 优化器状态 + 训练进度

实际应用中,仅保存参数通常是最灵活和高效的方式(需配合模型结构加载);完整模型适合快速部署但需注意兼容性;训练状态保存则用于中断后恢复训练。

http://www.xdnf.cn/news/18897.html

相关文章:

  • [肥用云计算] Serverless 多环境配置
  • PCM转音频
  • 面试之HashMap
  • LightRAG
  • 文档格式转换软件 一键Word转PDF
  • PPT处理控件Aspose.Slides教程:在 C# 中将 PPTX 转换为 Markdown
  • 【qml-7】qml与c++交互(自动补全提示)
  • [n8n] 全文检索(FTS)集成 | Mermaid图表生成
  • Android 使用MediaMuxer+MediaCodec编码MP4视频
  • 辅助驾驶出海、具身智能落地,稀缺的3D数据从哪里来?
  • 介绍智慧城管十大核心功能之一:风险预警系统
  • 架构评审:构建稳定、高效、可扩展的技术架构(下)
  • Java8-21的核心特性以及用法
  • 揭开.NET Core 中 ToList () 与 ToArray () 的面纱:从原理到抉择
  • ⸢ 贰 ⸥ ⤳ 安全架构:数字银行安全体系规划
  • 上海控安:GB 44495-2024《汽车整车信息安全技术要求》标准解读和测试方案
  • 修改win11任务栏时间字体和小图标颜色
  • vue实现表格轮播
  • 力扣18:四数之和
  • Python 实现冒泡排序:从原理到代码
  • PDFMathTranslate:让科学PDF翻译不再难——技术原理与实践指南
  • 2024中山大学研保研上机真题
  • (附源码)基于Spring Boot公务员考试信息管理系统设计与实现
  • 2025年渗透测试面试题总结-36(题目+回答)
  • 数据结构Java--8
  • Linux基础优化(Ubuntu、Kylin)
  • vue2实现背景颜色渐变
  • Java基础 8.27
  • 神经网络|(十六)概率论基础知识-伽马函数·上
  • Linux系统性能优化全攻略:从CPU到网络的全方位监控与诊断