当前位置: 首页 > ai >正文

残差神经网络的案例

项目任务:运用残差网络模型来识别手写数字

代码实现:

import torch
print(torch.__version__)
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
training_data = datasets.MNIST(root='data',train=True,download=True,transform=ToTensor())
test_data = datasets.MNIST(root='data',train=False,download=True,transform=ToTensor())train_dataloader = DataLoader(training_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)for X,y in test_dataloader:print(f"Shape of X[N,C,H,W]:{X.shape}")print(f"Shape of y: {y.shape} {y.dtype}")breakdevice = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f"Using {device} device")import torch
import torch.nn as nn
import torch.nn.functional as F# 残差块定义
class ResBlock(nn.Module):def __init__(self, channels_in):super().__init__()self.conv1 = torch.nn.Conv2d(channels_in,30, kernel_size=5, padding=2)self.conv2 = torch.nn.Conv2d(30, channels_in, kernel_size=3, padding=1)def forward(self, x):out = self.conv1(x)out = self.conv2(out)return F.relu(out + x)  # 残差连接(out + 输入x)# ResNet网络定义
class ResNet(nn.Module):def __init__(self):super().__init__()self.conv1 = torch.nn.Conv2d(1,20,5)self.conv2 = torch.nn.Conv2d(20,15,3)self.maxpool = torch.nn.MaxPool2d(2)self.resblock1 = ResBlock(channels_in=20)  # 第一个残差块self.resblock2 = ResBlock(channels_in=15)  # 第二个残差块self.full_c = torch.nn.Linear(375, 10)     # 全连接层(输出维度10,对应10分类)def forward(self, x):size = x.shape[0]  # 获取批次大小# 第一段卷积+池化+残差块x = F.relu(self.maxpool(self.conv1(x)))x = self.resblock1(x)# 第二段卷积+池化+残差块x = F.relu(self.maxpool(self.conv2(x)))x = self.resblock2(x)# 展平后送入全连接层x = x.view(size, -1)x = self.full_c(x)return xmodel = ResNet().to(device)def train(dataloader,model,loss_fn,optimizer):model.train()batch_size_num = 1for X ,y in dataloader:X,y = X.to(device),y.to(device)pred = model(X)loss = loss_fn(pred,y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value = loss.item()if batch_size_num % 100 ==0:print(f"loss:{loss_value:>7f} [number:{batch_size_num}]")batch_size_num +=1
best_acc=0
def test(dataloader,model,loss_fn):global best_accsize = len(dataloader.dataset)num_batches= len(dataloader)model.eval()test_loss = 0correct = 0with torch.no_grad():for X ,y in dataloader:X,y = X.to(device),y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_pj_loss = test_loss / num_batchestest_acy = correct / size * 100print(f"Avg loss: {test_pj_loss:>7f} \n Accuray: {test_acy:>5.2f}%")if correct > best_acc:best_acc = correct# 保存模型的状态字典,而非整个模型torch.save(model.state_dict(), 'best.pth')  # 重点修改这里print(f"保存最佳模型,准确率: {test_acy:>5.2f}%")else:print(f"保存最佳模型,准确率: {test_acy:>5.2f}%")
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
# train(train_dataloader,model,loss_fn,optimizer)
# test(test_dataloader,model,loss_fn)
i=10
for j in range(i):print(f"Epoch {j+1}\n----------")train(train_dataloader, model,loss_fn,optimizer)test(test_dataloader,model,loss_fn)

这段代码是一个基于 PyTorch 实现的残差网络(ResNet),用于训练和测试 MNIST 手写数字识别任务。下面对代码进行解析:

1. 库导入与数据集准备

import torch
print(torch.__version__)  # 打印PyTorch版本
from torch import nn  # 神经网络模块
from torch.utils.data import DataLoader  # 数据加载工具
from torchvision import datasets  # 计算机视觉数据集
from torchvision.transforms import ToTensor  # 图像转张量的转换工具

导入了 PyTorch 核心库、神经网络模块、数据加载工具,以及处理 MNIST 数据集的相关工具。

ToTensor()用于将图像(PIL 格式)转换为 PyTorch 张量,并自动将像素值归一化到[0, 1]范围。

# 加载MNIST训练集和测试集
training_data = datasets.MNIST(root='data',  # 数据保存路径train=True,   # 训练集download=True,  # 若本地无数据则自动下载transform=ToTensor()  # 应用转换
)
test_data = datasets.MNIST(root='data',train=False,  # 测试集download=True,transform=ToTensor()
)

MNIST 是经典的手写数字数据集,包含 60000 张训练图和 10000 张测试图,每张图是 28x28 的灰度图(单通道),标签为 0-9 的数字。

# 数据加载器(按批次加载数据,方便批量训练)
train_dataloader = DataLoader(training_data, batch_size=64)  # 训练集批次大小64
test_dataloader = DataLoader(test_data, batch_size=64)  # 测试集批次大小64

DataLoader将数据集按batch_size分批,支持自动打乱数据、多线程加载等功能,是训练中高效读取数据的工具。

# 打印数据形状,验证数据格式
for X, y in test_dataloader:print(f"Shape of X[N,C,H,W]: {X.shape}")  # 输入图像形状print(f"Shape of y: {y.shape} {y.dtype}")  # 标签形状和类型break

输出示例:X[N,C,H,W]中,N=64(批次大小)、C=1(单通道灰度图)、H=28W=28(图像尺寸);y是长度为 64 的标签(类型为long,适合分类任务)。

2. 设备配置

# 自动选择训练设备(优先GPU,其次苹果芯片,最后CPU)
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f"Using {device} device")

深度学习训练通常需要 GPU 加速,这里自动检测并选择可用的加速设备,最大化训练效率。

3. 残差网络(ResNet)定义

残差网络的核心是残差块(Residual Block),通过 “跳跃连接”(将输入直接加到输出)缓解深层网络的梯度消失问题。

3.1 残差块(ResBlock)
class ResBlock(nn.Module):def __init__(self, channels_in):super().__init__()# 第一个卷积层:输入通道数→30,5x5卷积核,padding=2(保持尺寸)self.conv1 = torch.nn.Conv2d(channels_in, 30, kernel_size=5, padding=2)# 第二个卷积层:30→输入通道数,3x3卷积核,padding=1(保持尺寸)self.conv2 = torch.nn.Conv2d(30, channels_in, kernel_size=3, padding=1)def forward(self, x):out = self.conv1(x)  # 第一次卷积out = self.conv2(out)  # 第二次卷积return F.relu(out + x)  # 残差连接(输出+输入)+ ReLU激活

残差块的关键是out + x:将输入x直接加到卷积后的输出out上,实现 “跳跃连接”,确保梯度能有效回传。

卷积层的padding设置保证了输入和输出的尺寸一致,才能进行加法操作。

3.2 完整 ResNet 网络
class ResNet(nn.Module):def __init__(self):super().__init__()# 第一层卷积:输入1通道(灰度图)→20通道,5x5卷积核self.conv1 = torch.nn.Conv2d(1, 20, 5)# 第二层卷积:20通道→15通道,3x3卷积核self.conv2 = torch.nn.Conv2d(20, 15, 3)self.maxpool = torch.nn.MaxPool2d(2)  # 2x2最大池化(尺寸减半)self.resblock1 = ResBlock(channels_in=20)  # 第一个残差块(输入20通道)self.resblock2 = ResBlock(channels_in=15)  # 第二个残差块(输入15通道)self.full_c = torch.nn.Linear(375, 10)  # 全连接层(输出10类,对应0-9)def forward(self, x):size = x.shape[0]  # 获取批次大小(用于后续展平操作)# 第一段:卷积→池化→激活→残差块x = F.relu(self.maxpool(self.conv1(x)))  # conv1→池化(尺寸减半)→ReLUx = self.resblock1(x)  # 经过第一个残差块# 第二段:卷积→池化→激活→残差块x = F.relu(self.maxpool(self.conv2(x)))  # conv2→池化(尺寸减半)→ReLUx = self.resblock2(x)  # 经过第二个残差块# 展平特征图→全连接层输出x = x.view(size, -1)  # 展平为(batch_size, 特征数),这里特征数为375x = self.full_c(x)  # 输出10类的预测概率(未经过softmax)return x

网络整体流程:输入图像→卷积层提取特征→池化层降维→残差块增强特征→全连接层输出分类结果。

x.view(size, -1)将卷积后的三维特征图(batch, channel, height, width)展平为二维张量(batch, 特征数),才能输入全连接层。

全连接层输入维度375是根据前面的特征图尺寸计算的(具体为:经过多次卷积和池化后,特征图尺寸为 5x5,通道数 15,5×5×15=375)。

4. 模型初始化

model = ResNet().to(device)  # 实例化模型,并移动到之前选择的设备(GPU/CPU)

5. 训练与测试函数

5.1 训练函数(train)
def train(dataloader, model, loss_fn, optimizer):model.train()  # 设置模型为训练模式(启用dropout、批归一化更新等)batch_size_num = 1  # 批次计数器for X, y in dataloader:X, y = X.to(device), y.to(device)  # 数据移到设备# 前向传播:计算预测值pred = model(X)# 计算损失(预测值与真实标签的差异)loss = loss_fn(pred, y)# 反向传播与参数更新optimizer.zero_grad()  # 清空上一轮梯度loss.backward()  # 计算梯度optimizer.step()  # 更新参数# 每100个批次打印一次损失loss_value = loss.item()if batch_size_num % 100 == 0:print(f"loss: {loss_value:>7f} [number: {batch_size_num}]")batch_size_num += 1

核心流程:前向传播计算预测→计算损失→反向传播求梯度→优化器更新参数。

model.train():启用训练模式(例如,若有 dropout 层会随机丢弃神经元)。

5.2 测试函数(test)
best_acc = 0  # 记录最佳准确率def test(dataloader, model, loss_fn):global best_acc  # 引用全局变量size = len(dataloader.dataset)  # 测试集总样本数num_batches = len(dataloader)  # 测试集批次数model.eval()  # 设置模型为评估模式(关闭dropout等)test_loss = 0  # 总测试损失correct = 0  # 正确分类的样本数with torch.no_grad():  # 关闭梯度计算(节省内存,加速计算)for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)  # 预测test_loss += loss_fn(pred, y).item()  # 累加损失# 计算正确数:预测最大值的索引(类别)与真实标签一致correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 计算平均损失和准确率test_pj_loss = test_loss / num_batchestest_acy = correct / size * 100print(f"Avg loss: {test_pj_loss:>7f} \n Accuracy: {test_acy:>5.2f}%")# 保存准确率最高的模型if correct > best_acc:best_acc = correcttorch.save(model.state_dict(), 'best.pth')  # 保存模型参数(而非整个模型)print(f"保存最佳模型,准确率: {test_acy:>5.2f}%")else:print(f"当前准确率未超过最佳,最佳准确率: {best_acc/size*100:>5.2f}%")

核心作用:评估模型在测试集上的性能(损失和准确率),并保存表现最好的模型。

model.eval():切换到评估模式(例如,关闭 dropout,固定批归一化参数)。

with torch.no_grad():关闭梯度计算,减少内存占用,加速测试过程。

模型保存用model.state_dict():仅保存参数(权重和偏置),而非整个模型结构,更轻量且灵活。

6. 训练配置与执行

loss_fn = nn.CrossEntropyLoss()  # 交叉熵损失(适合分类任务,内置softmax)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # Adam优化器,学习率0.001# 训练10个epoch
epochs = 10
for j in range(epochs):print(f"Epoch {j+1}\n----------")train(train_dataloader, model, loss_fn, optimizer)  # 训练一轮test(test_dataloader, model, loss_fn)  # 测试一轮

CrossEntropyLoss:适用于多分类任务,自动对输出进行 softmax 处理,并计算与标签的交叉熵。

Adam:一种常用的优化器,结合了动量和自适应学习率,收敛速度快且稳定。

epoch:完整遍历一次训练集的次数,这里设置为 10 次,每次训练后测试模型性能。

总结

这段代码实现了一个简化版的残差网络,用于 MNIST 手写数字识别。核心亮点包括:

使用残差块解决深层网络梯度消失问题;

完整的训练 - 测试流程(含设备自动选择、损失计算、参数更新、模型保存);

符合 PyTorch 最佳实践(如train()/eval()模式切换、torch.no_grad()关闭梯度等)。

通过训练,模型通常能达到 98% 以上的准确率,残差结构相比普通卷积网络能更高效地学习特征。

http://www.xdnf.cn/news/19885.html

相关文章:

  • 【面试题】LangChain与LlamaIndex核心概念详解
  • 聚焦GISBox矢量服务:数据管理、数据库连接与框架预览全攻略
  • 分布式电源接入电网进行潮流计算
  • Linux笔记---UDP套接字实战:简易聊天室
  • 服务器不支持node.js16以上版本安装?用Docker轻松部署Node.js 20+环境运行Strapi项目
  • 新规则,新游戏:AI时代下的战略重构与商业实践
  • 安全领域必须关注每年发布一次“最危险的25种软件弱点”清单 —— CWE Top 25(内附2024 CWE Top 25清单详情)
  • Boost搜索引擎 数据清洗与去标签(1)
  • 【OpenHarmony文件管理子系统】文件访问接口mod_fs解析
  • ECMAScript(2)核心语法课件(Node.js/React 环境)
  • uniapp的上拉加载H5和小程序
  • PDF.AI-与你的PDF文档对话
  • C++虚函数虚析构函数纯虚函数的使用说明和理解
  • redisson延迟队列报错Sync methods can‘t be invoked from async_rx_reactive listeners
  • 快速排序算法详解
  • 【mysql】SQL自连接实战:查询温度升高的日期
  • 三维多相机光场扫描:打造元宇宙时代的“数字自我”
  • React学习教程,从入门到精通, React 嵌套组件语法知识点(10)
  • 公司机密视频泄露频发?如何让机密视频只在公司内部播放
  • 数据采集机器人哪家好?2025 年实测推荐:千里聆 RPA 凭什么成企业首选?
  • 机器人智能控制领域技术路线
  • 嵌入式 - 硬件:51单片机(3)uart串口
  • 【Java EE进阶 --- SpringBoot】Spring IoC
  • 鸿蒙:从图库选择图片并上传到服务器
  • 什么情况下会用到ConcurrentSkipListMap
  • 【系统架构设计(15)】软件架构设计一:软件架构概念与基于架构的软件开发
  • PDF Reader 编辑阅读工具(Mac中文)
  • Linux 常用命令全解析:从入门到实战的必备指南
  • TypeScript 增强功能大纲 (相对于 ECMAScript)
  • 如何轻松地将联系人从 Mac 同步到 iPhone