当前位置: 首页 > java >正文

Day 40:训练和测试的规范写法

1. 单通道图片的规范准备工作

首先,回顾一下基础设置。用 MNIST 数据集(单通道灰度图像)作为例子。代码开头导入必要的库,并设置设备和随机种子,确保结果可复现。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

这里,transforms.Compose 用于数据预处理:转为张量、归一化(MNIST 的均值和标准差是固定的)。然后加载数据集,分成训练集和测试集,用 DataLoader 打包成批次(batch_size=64)。

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

模型定义是个简单的 MLP:展平图像、两层全连接 + ReLU。损失函数用交叉熵,优化器选 Adam。

class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.flatten = nn.Flatten()self.layer1 = nn.Linear(784, 128)self.relu = nn.ReLU()self.layer2 = nn.Linear(128, 10)def forward(self, x):x = self.flatten(x)x = self.layer1(x)x = self.relu(x)x = self.layer2(x)return xmodel = MLP().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

这些都是老生常谈,但规范写法从这里开始体现:数据和模型分离,便于后续扩展。

2. 训练函数的封装:逻辑隔离,参数复用

以前写鸢尾花 MLP 时,训练代码直接扔在主程序里,看起来乱糟糟的。现在,我们用函数 train 封装整个过程。为什么这么做?一是参数(如 epochs)调整方便,不用翻代码;二是复用性强,以后多模型对比时,直接调用就好。
函数里记录每个 batch 的损失(iteration 级),每 100 batch 打印一次,epoch 结束时测试并绘图。注意,早停策略暂不加(需要验证集),但测试函数独立封装。

def train(model, train_loader, test_loader, criterion, optimizer, device, epochs):model.train()all_iter_losses = []iter_indices = []for epoch in range(epochs):running_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)running_loss += iter_loss_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()if (batch_idx + 1) % 100 == 0:print(f'Epoch: {epoch+1}/{epochs} | Batch: {batch_idx+1}/{len(train_loader)} 'f'| 单Batch损失: {iter_loss:.4f} | 累计平均损失: {running_loss/(batch_idx+1):.4f}')epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct / totalepoch_test_loss, epoch_test_acc = test(model, test_loader, criterion, device)print(f'Epoch {epoch+1}/{epochs} 完成 | 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%')plot_iter_losses(all_iter_losses, iter_indices)return epoch_test_acc

测试函数 test 也很干净:eval 模式、无梯度计算,计算平均损失和准确率。

def test(model, test_loader, criterion, device):model.eval()test_loss = 0correct = 0total = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()avg_loss = test_loss / len(test_loader)accuracy = 100. * correct / totalreturn avg_loss, accuracy

绘图函数绘制 iteration 损失曲线,更直观地看训练过程。

def plot_iter_losses(losses, indices):plt.figure(figsize=(10, 4))plt.plot(indices, losses, 'b-', alpha=0.7, label='Iteration Loss')plt.xlabel('Iteration(Batch序号)')plt.ylabel('损失值')plt.title('每个 Iteration 的训练损失')plt.legend()plt.grid(True)plt.tight_layout()plt.show()

3. 运行与效果分析

直接调用训练:epochs=2(实际可调大),看看效果。

epochs = 2
print("开始训练模型...")
final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, device, epochs)
print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")

运行后,会看到每 100 batch 的损失打印,epoch 结束的准确率,以及损失曲线图。MNIST 上 MLP 能到 96%+,但如果换 CIFAR-10,准确率可能卡在 50% 左右。为什么?MLP 忽略图像空间结构,全连接层参数爆炸,容易过拟合。

@浙大疏锦行

http://www.xdnf.cn/news/18328.html

相关文章:

  • Flink实现Exactly-Once语义的完整技术分解
  • 利用无事务方式插入数据库解决并发插入问题(最小主键id思路)
  • idea进阶技能掌握, 自带HTTP测试工具HTTP client使用方法详解,完全可替代PostMan
  • 暖哇科技AI调查智能体上线,引领保险调查风控智能化升级
  • 【数据结构】排序算法全解析:概念与接口
  • RK android14 Setting一级菜单IR遥控器无法聚焦问题解决方法
  • Apache ShenYu和Nacos之间的通信原理
  • VPS海外节点性能监控全攻略:从基础配置到高级优化
  • Android 入门到实战(三):ViewPager及ViewPager2多页面布局
  • 数据预处理学习心得:从理论到实践的桥梁搭建
  • 比剪映更轻量!SolveigMM 视频无损剪切实战体验
  • 29.Linux rsync+inotify解决同步数据实时性
  • 3D检测笔记:相机模型与坐标变换
  • 详解 scikit-learn 数据预处理工具:从理论到实践
  • CS+ for CC编译超慢的问题该如何解决
  • Day23 双向链表
  • 计算机网络--HTTP协议
  • 亚马逊新品爆单策略:从传统困境到智能突破
  • 【Grafana】grafana-image-renderer配合python脚本实现仪表盘导出pdf
  • 给你的Unity编辑器添加实现类似 Odin 的 条件显示字段 (ShowIf/HideIf) 功能
  • word——如何给封面、目录、摘要、正文设置不同的页码
  • 路由器NAT的类型测定
  • vue:vue中的ref和reactive
  • 【LLMs篇】18:基于EasyR1的Qwen2.5-VL GRPO训练
  • 层在init中只为创建线性层,forward的对线性层中间加非线性运算。且分层定义是为了把原本一长个代码的初始化和运算放到一个组合中。
  • 机械革命电竞控制台一直加载无法点击故障
  • MySQL事务及原理详解
  • 牛津大学xDeepMind 自然语言处理(3)
  • 工业电脑选得好生产效率节节高稳定可靠之选
  • C/C++ 与嵌入式岗位常见笔试题详解