当前位置: 首页 > ds >正文

第三十一篇 AI的“能力考”:模型评估、保存与加载的艺术【总结前面3】

模型保存到推理

  • 前言:从“学习”到“应用”的最后一步
  • 第一章:AI的“成绩单”——模型评估
    • 1.1 评估模式:model.eval()与torch.no_grad()的智慧
    • 1.2 分类任务的“及格线”:准确率(Accuracy)
    • 1.3 亲手评估你的AI模型准确率
  • 第二章:AI“智慧”的持久化——模型保存
    • 2.1 PyTorch的“传统艺能”:torch.save与state_dict
    • 2.2 新王登基:.safetensors的安全与高效
    • 2.3 将你的AI模型安全“存档”
  • 第三章:AI的“重生”与“实战”——模型加载与推理
    • 3.1 加载模型:让AI的“智慧”重新焕发生机
    • 3.2 加载模型并进行单张图片推理
  • 总结与展望:你已拥有AI的“入门级驾驶证”

前言:从“学习”到“应用”的最后一步

在上一章,我们亲手搭建并驱动了一个AI的“思考引擎”,它正在努力学习识别手写数字。但模型学得怎么样?它的“识字”能力达到了什么水平?更重要的是,辛辛苦苦训练出的这个AI“大脑”,如何才能保存下来,以便未来可以投入实际应用,或者分享给他人呢?
使用模型注意事项

这些问题,正是我们今天将要解决的。本章将带领你完成AI学习的最后闭环:模型评估(检验成果)、模型保存(持久化智慧)和模型加载与推理(让智慧重焕光彩,投入实战)。

第一章:AI的“成绩单”——模型评估

在训练过程中,我们关注损失的下降。但损失下降并不代表模型真的“学得好”,我们还需要用独立的测试集来评估它的泛化能力
AI成绩单

1.1 评估模式:model.eval()与torch.no_grad()的智慧

在评估模型时,我们必须切换到评估模式。这是为了:

model.eval():告诉模型现在是评估阶段,关闭nn.Dropout(防止随机丢弃神经元)和nn.BatchNorm(停止更新均值和方差,使用训练时的统计量)等层。这确保了模型在评估时的行为是确定的、可重复的。

torch.no_grad():创建一个上下文管理器,告诉PyTorch在这个代码块内部,不要计算梯度。
为什么? 评估阶段我们不需要进行反向传播来更新参数,计算梯度是多余的,这会浪费计算资源和内存。

好处:加速评估过程,减少内存占用。

1.2 分类任务的“及格线”:准确率(Accuracy)

对于分类任务,最直观、最常用的评估指标就是准确率(Accuracy)。

准确率 = (正确预测的样本数 / 总样本数) * 100%

1.3 亲手评估你的AI模型准确率

目标:对上一章训练好的MNIST分类器进行准确率评估,并可视化部分预测结果。

前置:你需要确保上一章simple_mnist_classifier_full.py已经运行过,并且其训练好的模型权重文件simple_mnist_mlp.pth已经保存在mnist_results/目录下

代码展示

# case_10_3_model_evaluation.pyimport torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os# --- 0. 定义模型结构和加载权重 (与训练时保持一致) ---
# 这部分代码必须和训练模型的 SimpleMLPC Classifier 类定义完全相同
class SimpleMLPClassifier(nn.Module):def __init__(self, input_dim, hidden_dim, output_dim):super(SimpleMLPClassifier, self).__init__()self.fc1 = nn.Linear(input_dim, hidden_dim)self.relu = nn.ReLU()self.fc2 = nn.Linear(hidden_dim, output_dim)def forward(self, x):x = x.view(-1, input_dim) # 注意这里的 input_dim 变量需要传入或定义out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return out# --- 加载训练时定义的超参数 ---
INPUT_DIM = 28 * 28
HIDDEN_DIM = 256
OUTPUT_DIM = 10
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = 'mnist_results/simple_mnist_mlp.pth' # 上一章保存的模型权重路径def evaluate_model(model, test_loader, device):model.eval() # 设置模型为评估模式correct = 0total = 0# 用于记录错误预测,以便可视化wrong_predictions = [] with torch.no_grad(): # 在评估时,禁用梯度计算for data, target in test_loader:data, target = data.to(device), target.to(device) # 将数据移动到设备outputs = model(data) # 前向传播,获取模型预测的Logits# 从Logits中找到预测概率最高的类别# torch.max(outputs.data, 1) 返回每一行最大值及其索引。1表示在维度1上求最大值_, predicted = torch.max(outputs.data, 1) # _ 是最大值,我们只关心索引(predicted类别)total += target.size(0) # 累加当前批次的样本总数# 比较预测类别和真实类别,并累加正确预测的数量correct += (predicted == target).sum().item() # 记录错误的预测 (用于可视化)incorrect_mask = (predicted != target)for i in range(len(incorrect_mask)):if incorrect_mask[i]:wrong_predictions.append((data[i].cpu(), target[i].cpu(), predicted[i].cpu()))accuracy = 100 * correct / totalprint(f'在 {total} 张测试图片上的准确率: {accuracy:.2f}%')return accuracy, wrong_predictionsdef visualize_predictions(test_loader, model, device, num_display=10):model.eval()# 随机选择一个Batch用于可视化data_iter = iter(test_loader)data, labels = next(data_iter)data, labels = data.to(device), labels.to(device)with torch.no_grad():outputs = model(data)_, predicted = torch.max(outputs.data, 1)plt.figure(figsize=(12, 6))plt.suptitle("模型预测结果示例 (绿色为正确,红色为错误)", fontsize=16)for i in range(num_display):plt.subplot(2, 5, i + 1)# .squeeze() 移除单维度通道 (例如 [1, 28, 28] -> [28, 28])plt.imshow(data[i].cpu().squeeze(), cmap='gray')is_correct = (predicted[i] == labels[i]).item()color = 'green' if is_correct else 'red'plt.title(f"Pred: {predicted[i].item()}\nTrue: {labels[i].item()}", color=color)plt.axis('off')plt.tight_layout(rect=[0, 0.03, 1, 0.95])plt.savefig('mnist_results/sample_predictions.png')plt.show()# --- 主执行流程 ---
if __name__ == '__main__':# 确保mnist_results目录存在os.makedirs('mnist_results', exist_ok=True)# 加载测试数据集transform = transforms.ToTensor()test_dataset = datasets.MNIST('data', train=False, download=True, transform=transform)test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False) # 批次大小可以设小一些# 实例化模型model_for_eval = SimpleMLPClassifier(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM).to(DEVICE)# 加载预训练权重 (这是关键!确保你有这个文件)try:model_for_eval.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))print(f"✅ 模型权重已从 '{MODEL_PATH}' 成功加载。")except FileNotFoundError:print(f"❌ 错误:未找到模型权重文件 '{MODEL_PATH}'。请先运行上一章代码训练并保存模型!")exit() # 如果模型文件不存在,无法继续print("\n--- 开始评估模型 ---")accuracy, _ = evaluate_model(model_for_eval, test_loader, DEVICE)print(f"\n模型最终准确率为: {accuracy:.2f}%")print("\n--- 可视化部分预测结果 ---")visualize_predictions(test_loader, model_for_eval, DEVICE, num_display=10)print(f"部分预测结果图已保存到: mnist_results/sample_predictions.png")

代码解读与见证奇迹】

运行这段代码,你将看到模型在MNIST测试集上的准确率,通常会达到90%以上。
可视化部分,你会看到一系列原始的数字图片,上面标注着模型的预测结果和真实标签。正确预测的标题是绿色的,错误预测是红色的,让你直观感受到AI的“识字”能力!

这证明了我们亲手搭建的MLP模型,经过简单的训练,已经具备了对未见过数据进行识别的泛化能力。

第二章:AI“智慧”的持久化——模型保存

学习如何将辛苦训练好的AI模型“存档”到硬盘,并深入对比torch.save(基于Pickle)和safetensors两种主流保存方式的安全与效率差异。
模型持久化

2.1 PyTorch的“传统艺能”:torch.save与state_dict

这是PyTorch原生的保存方法。最推荐的做法是只保存模型的state_dict
model.state_dict():它返回一个Python字典,包含了模型所有可学习参数(权重和偏置)的副本。这就像模型的“灵魂”或“基因组”。

优点:文件小,只包含数据,不包含代码逻辑,加载时需明确模型结构,相对安全。

缺点:底层使用Python的pickle协议,这会带来潜在的安全风险

2.2 新王登基:.safetensors的安全与高效

为了解决pickle的安全问题,Hugging Face社区推出了**.safetensors**格式。

核心优势:

  1. 绝对安全:只存储Tensor的原始二进制数据和JSON格式的元数据(形状、类型),不包含任何可

2.执行代码。safetensors.torch.load_file()不会执行任意代码。

  1. 加载极快:特别是在部分加载(分片)或跨语言加载时,其效率远超pickle。

  2. 跨平台/框架:设计上就考虑了不同AI框架(PyTorch, TensorFlow, JAX)的兼容性。
    推荐:在分享模型或加载未知来源的模型时,优先使用.safetensors格式。

2.3 将你的AI模型安全“存档”

目标:使用torch.save和safetensors两种方法,保存我们训练好的MLP分类器的权重。

前置:假设你已经运行了上一章的代码,并且训练好的模型实例trained_model可用。

# case_10_3_model_saving.pyimport torch
import torch.nn as nn
import os
# 需要安装safetensors库: pip install safetensors
from safetensors.torch import save_file, load_file# --- 0. 准备工作:定义模型结构 (同评估时一致) ---
class SimpleMLPClassifier(nn.Module):def __init__(self, input_dim, hidden_dim, output_dim):super(SimpleMLPClassifier, self).__init__()self.fc1 = nn.Linear(input_dim, hidden_dim)self.relu = nn.ReLU()self.fc2 = nn.Linear(hidden_dim, output_dim)def forward(self, x):x = x.view(-1, input_dim) # 注意这里的 input_dim 变量需要传入或定义out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return out# --- 导入训练时定义的超参数 (或直接定义) ---
INPUT_DIM = 28 * 28
HIDDEN_DIM = 256
OUTPUT_DIM = 10
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_DIR = 'mnist_results' # 结果目录
os.makedirs(MODEL_DIR, exist_ok=True)# 假设我们有一个已经训练好的模型实例
# 在实际运行中,你可以从上一章的 main_training_loop 返回 trained_model
# 这里我们为了独立运行,先实例化并模拟加载权重
model_to_save = SimpleMLPClassifier(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM).to(DEVICE)
# 模拟加载一个随机权重,表示它是“训练好的”
model_to_save.load_state_dict(torch.load('mnist_results/simple_mnist_mlp.pth', map_location=DEVICE))
model_to_save.eval()
print("模型实例已准备好进行保存。")# --- 1. 策略一:使用 torch.save 只保存 State Dict (.pth/.pt) ---
print("\n--- 策略一:使用 torch.save 保存模型权重 ---")
pytorch_save_path = os.path.join(MODEL_DIR, 'simple_mnist_mlp_weights_torch.pth')
# model.state_dict() 获取模型的权重字典
torch.save(model_to_save.state_dict(), pytorch_save_path)
print(f"✅ 模型权重已用 torch.save 保存到: {pytorch_save_path}")# --- 2. 策略二:使用 safetensors 保存 State Dict (.safetensors) ---
print("\n--- 策略二:使用 safetensors 安全保存模型权重 ---")
safetensors_save_path = os.path.join(MODEL_DIR, 'simple_mnist_mlp.safetensors')
# safetensors.torch.save_file() 函数更推荐
save_file(model_to_save.state_dict(), safetensors_save_path)
print(f"✅ 模型权重已用 safetensors 安全保存到: {safetensors_save_path}")print("\n模型已成功保存到两种文件格式中!")

代码解读与安全警示】
运行这段代码,你会看到两个文件被创建:.pth和.safetensors。它们都包含了模型的权重信息,但背后的安全性却天差地别。
⚠️ 再次提醒: torch.save()(使用Pickle协议)在加载不信任来源的文件时存在任意代码执行的风险。safetensors则从设计上规避了这一风险,是更安全的选择。

第三章:AI的“重生”与“实战”——模型加载与推理

学习如何将保存的模型加载回内存,并用它来对新的、单张图片进行推理预测。
模型重生

3.1 加载模型:让AI的“智慧”重新焕发生机

模型加载,就是将硬盘上的模型文件重新读取到内存中,再次实例化为PyTorch模型对象的过程。

加载state_dict: 始终是首选。你需要先定义模型的类结构(即SimpleMLPClassifier),然后使用

model.load_state_dict(torch.load(…))来加载权重。

map_location: 在torch.load()时,map_location参数非常有用。它可以指定将模型加载到CPU或GPU,尤其当你在GPU上训练但在CPU上推理时。

3.2 加载模型并进行单张图片推理

目标:加载我们之前保存的.safetensors模型,并用它来预测一张新的手写数字图片。
前置:确保simple_mnist_mlp.safetensors文件已存在。

# case_10_3_model_loading_inference.pyimport torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import os
# 需要安装safetensors库: pip install safetensors
from safetensors.torch import save_file, load_file # 导入load_file函数# --- 0. 准备工作:定义模型结构 (必须和训练/保存时完全一致) ---
class SimpleMLPClassifier(nn.Module):def __init__(self, input_dim, hidden_dim, output_dim):super(SimpleMLPClassifier, self).__init__()self.fc1 = nn.Linear(input_dim, hidden_dim)self.relu = nn.ReLU()self.fc2 = nn.Linear(hidden_dim, output_dim)def forward(self, x):x = x.view(-1, input_dim) # 注意这里的 input_dim 变量需要传入或定义out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return out# --- 定义模型超参数和文件路径 ---
INPUT_DIM = 28 * 28
HIDDEN_DIM = 256
OUTPUT_DIM = 10
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH_SAFETENSORS = 'mnist_results/simple_mnist_mlp.safetensors' # 使用safetensors保存的文件# --- 1. 加载模型 ---
print("--- 1. 加载模型 ---")
# 实例化一个新的模型“躯体”
loaded_model = SimpleMLPClassifier(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM).to(DEVICE)# 加载保存的权重文件 (.safetensors)
try:# load_file 返回一个state_dictloaded_state_dict = load_file(MODEL_PATH_SAFETENSORS)loaded_model.load_state_dict(loaded_state_dict)print(f"✅ 模型权重已从 '{MODEL_PATH_SAFETENSORS}' 成功加载。")
except FileNotFoundError:print(f"❌ 错误:未找到模型权重文件 '{MODEL_PATH_SAFETENSORS}'。请先运行上一章代码训练并保存模型!")exit()
loaded_model.eval() # 设置为评估模式,非常重要!
print("模型已准备好进行推理。")# --- 2. 准备一张新的测试图片进行推理 ---
print("\n--- 2. 准备新的测试图片进行推理 ---")
# 从MNIST测试集中随机取一张图片
transform = transforms.ToTensor()
test_dataset = datasets.MNIST('data', train=False, download=True, transform=transform)
# 随机选择一个索引
random_idx = np.random.randint(0, len(test_dataset))
sample_image, true_label = test_dataset[random_idx]print(f"随机选择的图片索引: {random_idx}")
print(f"真实标签: {true_label}")# 将图片添加到Batch维度,并移动到设备
# 从 [1, 28, 28] -> [1, 1, 28, 28] (增加Batch维度)
input_for_inference = sample_image.unsqueeze(0).to(DEVICE) 
print(f"用于推理的图片形状: {input_for_inference.shape}")# --- 3. 进行推理预测 ---
print("\n--- 3. 模型进行推理预测 ---")
with torch.no_grad(): # 推理时禁用梯度计算output_logits = loaded_model(input_for_inference)# 获取概率分布 (可选,CrossEntropyLoss内部已做)probabilities = F.softmax(output_logits, dim=1)# 获取预测类别_, predicted_class = torch.max(output_logits, 1)predicted_class_item = predicted_class.item() # 从Tensor中提取预测类别数值
predicted_prob = probabilities[0, predicted_class_item].item() # 获取预测类别的概率print(f"模型预测类别: {predicted_class_item}")
print(f"预测概率: {predicted_prob*100:.2f}%")# --- 4. 可视化结果 ---
plt.figure(figsize=(4, 4))
plt.imshow(sample_image.squeeze().numpy(), cmap='gray') # 显示图片
plt.title(f"预测: {predicted_class_item} (真实: {true_label})", color='green' if predicted_class_item == true_label else 'red', fontsize=14)
plt.axis('off')
plt.tight_layout()
plt.savefig(os.path.join(MODEL_DIR, f'inference_result_{random_idx}.png'))
plt.show()print("\n🎉 模型推理完成,结果已可视化!")`在这里插入代码片`

代码解读与见证奇迹】
运行这段代码,你会看到:
模型权重从.safetensors文件被成功加载。
随机从测试集中选择一张图片,并显示出来。
模型对这张图片给出了准确的预测,并且显示了对应的概率。
这证明了你的AI模型,在经过训练、保存、加载之后,能够重新“复活”,并投入到实际的推理应用中。你已经掌握了AI模型从“实验室”走向“真实世界”的最后一步

总结与展望:你已拥有AI的“入门级驾驶证”

总结与展望:你已拥有AI的“入门级驾驶证”
恭喜你!今天你已经通过亲手编写和运行代码,完成了AI学习流程的最后闭环。
✨ 本章惊喜概括 ✨

你掌握了什么?对应的核心操作/概念
准确评估模型✅ model.eval(), torch.no_grad(), 准确率计算
AI“智慧”的持久化✅ torch.save和safetensors保存state_dict
AI的“重生”✅ 加载模型权重,map_location
AI的“实战应用”✅ 对单张图片进行推理预测
你现在不仅仅是“听说过”AI模型,你已经能够从零开始搭建、训练、评估、保存、加载、并使用一个完整的AI模型了! 你已经拥有了AI世界的“入门级驾驶证”,可以自信地开始探索更复杂的AI应用和架构。
http://www.xdnf.cn/news/16847.html

相关文章:

  • MBR与GPT分区表深度解析:硬盘分区该怎么选?
  • pip库版本升级
  • Android Studio 中Revert Commit、Undo Commit 和 Drop Commit 使用场景
  • Android Studio怎么显示多排table,打开文件多行显示文件名
  • 现在有哪些广泛使用的时序数据库?
  • [免费]基于Python的招聘职位信息推荐系统(猎聘网数据分析与可视化)(Django+requests库)【论文+源码+SQL脚本】
  • [mind-elixir]Mind-Elixir 的交互增强:单击、双击与鼠标 Hover 功能实现
  • Web3.0 和 Web2.0 生态系统比较分析:差异在哪里?
  • 【Datawhale AI夏令营】科大讯飞AI大赛(大模型技术)/夏令营:让AI理解列车排期表(Task3)
  • 【python 获取邮箱验证码】模拟登录并获取163邮箱验证码,仅供学习!仅供测试!仅供交流!
  • uni-app webview的message监听不生效(uni.postmessage is not a function)
  • linux 执行sh脚本,提示$‘\r‘: command not found
  • 从一开始的网络攻防(十四):WAF绕过
  • day21-Excel文件解析
  • 【MySQL 数据库】MySQL索引特性(一)磁盘存储定位扇区InnoDB页
  • AI 代码助手在大前端项目中的协作开发模式探索
  • C++ Qt网络编程实战:跨平台TCP调试工具开发
  • 容器与虚拟机的本质差异:从资源隔离到网络存储机制
  • 2020 年 NOI 最后一题题解
  • Apple基础(Xcode②-Flutter结构解析)
  • 【硬件-笔试面试题】硬件/电子工程师,笔试面试题-49,(知识点:OSI模型,物理层、数据链路层、网络层)
  • 2025年湖北中级注册安全工程师报考那些事
  • 网络安全学习第16集(cdn知识点)
  • 知识速查大全:python面向对象基础
  • C++从入门到起飞之——智能指针!
  • 电子电气架构 --- 区域架构让未来汽车成为现实
  • 深入理解PostgreSQL的MVCC机制
  • SpringBoot之多环境配置全解析
  • Linux 系统日志管理与时钟同步实用指南
  • Tlias 案例-整体布局(前端)