当前位置: 首页 > news >正文

从零开始构建卷积神经网络(CNN)进行MNIST手写数字识别

在深度学习领域,卷积神经网络(CNN)凭借其对图像特征的出色提取能力,成为图像识别任务的核心模型。本文将基于 PyTorch 框架,从零搭建一个卷积神经网络,完成 MNIST 手写数字数据集的识别任务,并详细讲解从数据加载、模型构建到训练与测试的完整流程。

一、环境准备与数据加载

1.1 依赖库导入

首先需要导入实验所需的 Python 库,包括 PyTorch 核心库、数据加载与处理相关库,以及数据集模块:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

1.2 查看 PyTorch 版本

确认当前使用的 PyTorch 版本,确保代码兼容性:

print(torch.__version__)

1.3 MNIST 数据集加载

MNIST 是手写数字数据集,包含 60000 张训练图片和 10000 张测试图片,每张图片为 28×28 像素的灰度图,标签为 0-9 的数字。通过torchvision.datasets.MNIST可直接下载并加载数据集,ToTensor()会将图像转换为 PyTorch 支持的张量格式(取值范围从 0-255 归一化到 0-1):

# 加载训练集
training_data = datasets.MNIST(root='data',  # 数据集保存路径train=True,   # 标记为训练集download=True,  # 若本地无数据集则自动下载transform=ToTensor(),  # 数据转换:图像→张量
)# 加载测试集
test_data = datasets.MNIST(root='data',train=False,  # 标记为测试集download=True,transform=ToTensor(),
)# 查看训练集样本数量
print(f"训练集样本数:{len(training_data)}")
print(f"测试集样本数:{len(test_data)}")

1.4 数据批量加载(DataLoader)

为了提高训练效率,使用DataLoader将数据集按批次(batch)划分,每次训练时批量读取数据,同时支持数据打乱(仅训练集)和并行加载:

# 训练集DataLoader:批量大小64,打乱数据
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
# 测试集DataLoader:批量大小64,无需打乱
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)# 查看数据张量形状
for X, y in test_dataloader:print(f"输入图像张量形状 [批次大小, 通道数, 高度, 宽度]:{X.shape}")print(f"标签张量类型:{y.dtype}")print(f"标签示例:{y[:5]}")break

输出结果中,X.shape(64, 1, 28, 28),代表每批次包含 64 张 1 通道(灰度图)、28×28 像素的图像;y为标签张量,类型为整数,对应手写数字的真实值。

1.5 设备配置(CPU/GPU)

PyTorch 支持 CPU 和 GPU 训练,通过以下代码自动检测并使用可用的计算设备(优先 GPU,其次 CPU):

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f"使用的计算设备:{device}")

二、卷积神经网络模型构建

根据 PPT 中介绍的 CNN 核心结构(卷积层、池化层、全连接层),我们设计一个包含 3 个卷积模块和 1 个全连接层的网络,用于 MNIST 数字识别。

2.1 模型结构设计

模型整体流程:输入图像→卷积层 1(含 ReLU 激活 + 最大池化)→卷积层 2(含 ReLU 激活)→卷积层 3(含 ReLU 激活 + 最大池化)→展平→全连接层→输出。各层参数设计参考 PPT 中 “卷积层计算原理” 和 “池化层作用”:

  • 卷积层:使用 3×3 或 5×5 卷积核,通过padding保持特征图尺寸,out_channels逐步增加以提取更复杂的特征;
  • 池化层:采用最大池化(MaxPool2d),步长为 2,将特征图尺寸缩小一半,减少参数数量和计算量;
  • 全连接层:将池化后的特征图展平为一维向量,映射到 10 个输出(对应 0-9 数字类别)。

2.2 模型代码实现

通过nn.Module定义自定义网络类,使用nn.Sequential简化层的堆叠:

class SequentialNetwork(nn.Module):def __init__(self):super().__init__()# 卷积模块1:1→16通道,5×5卷积核,padding=2(保持尺寸),后接ReLU和最大池化self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1,    # 输入通道数(灰度图为1)out_channels=16,  # 输出通道数(卷积核数量)kernel_size=5,    # 卷积核大小stride=1,         # 步长padding=2,        # 边缘填充,使输出尺寸=输入尺寸),nn.ReLU(),  # 激活函数,引入非线性nn.MaxPool2d(kernel_size=2)  # 最大池化,尺寸缩小为14×14)# 卷积模块2:16→32通道,5×5卷积核,无池化(保持14×14尺寸)self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),nn.ReLU(),)# 卷积模块3:32→64通道,5×5卷积核,后接最大池化(尺寸缩小为7×7)self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(kernel_size=2))# 全连接层:将64×7×7的特征图展平后,映射到10个类别self.out = nn.Linear(64 * 7 * 7, 10)# 前向传播:定义数据在网络中的流动路径def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)  # 展平:(batch_size, 64*7*7)x = self.out(x)return x# 创建模型实例,并移动到指定设备
model = SequentialNetwork().to(device)
print("卷积神经网络模型结构:")
print(model)

三、训练与测试函数定义

3.1 训练函数(train)

训练函数负责模型的迭代训练,包括前向传播(计算预测值)、损失计算、反向传播(更新梯度)和参数优化。根据 PPT 中 “训练控制模块” 的思路,需注意:

  • 调用model.train()切换到训练模式(启用 dropout、批量归一化等训练特有的操作);
  • 每次迭代前清空梯度(optimizer.zero_grad()),避免梯度累积;
  • 每 100 个批次打印一次损失值,监控训练进度。
def train(dataloader, model, loss_fn, optimizer):model.train()  # 切换到训练模式batch_size_num = 1  # 批次计数器for X, y in dataloader:# 将数据移动到计算设备X, y = X.to(device), y.to(device)# 前向传播:计算模型预测值pred = model(X)# 计算损失(交叉熵损失,适用于分类任务)loss = loss_fn(pred, y)# 反向传播与参数优化optimizer.zero_grad()  # 清空梯度loss.backward()        # 反向传播计算梯度optimizer.step()       # 更新模型参数# 每100个批次打印损失loss_value = loss.item()if batch_size_num % 100 == 0:print(f"训练批次:{batch_size_num:>4d} | 损失值:{loss_value:>7f}")batch_size_num += 1

3.2 测试函数(test)

测试函数用于评估模型在测试集上的性能,包括计算准确率和平均损失。根据 PPT 中 “模型评估” 的要求,需注意:

  • 调用model.eval()切换到测试模式(禁用 dropout、固定批量归一化参数);
  • 使用torch.no_grad()禁用梯度计算,减少内存占用和计算时间;
  • 统计所有测试样本的正确预测数,计算准确率。
def test(dataloader, model, loss_fn):model.eval()  # 切换到测试模式size = len(dataloader.dataset)  # 测试集总样本数num_batches = len(dataloader)    # 测试集批次数test_loss, correct = 0, 0        # 总损失和正确预测数# 禁用梯度计算with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)# 累加损失和正确预测数test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 计算平均损失和准确率test_loss /= num_batchescorrect /= sizeprint(f"\n测试集结果:")print(f"准确率:{(100 * correct):>0.1f}% | 平均损失:{test_loss:>8f}")

四、模型训练与结果评估

4.1 损失函数与优化器配置

  • 损失函数:选用nn.CrossEntropyLoss(),适用于多分类任务,内置了 Softmax 激活函数;
  • 优化器:选用 Adam 优化器(比 SGD 收敛更快),学习率设置为 0.01(参考 PPT 中 “调整学习率” 的建议,初始学习率不宜过大或过小)。
# 定义损失函数
loss_fn = nn.CrossEntropyLoss()# 定义优化器
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

4.2 迭代训练

设置训练轮次(epochs)为 20,每轮训练后在测试集上评估模型性能:

# 训练轮次
epochs = 20print(f"开始训练(共{epochs}轮):")
for t in range(epochs):print(f"\n==================== 第{t+1}轮训练 ====================")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)print("\n训练完成!")

4.3 预期结果与分析

在 MNIST 数据集上,该卷积神经网络经过 20 轮训练后,通常可达到 98% 以上的测试准确率。若准确率较低,可参考 PPT 中的优化方向:

  • 数据增强:添加随机旋转、裁剪等操作(如 PPT “数据增强” 模块),增加训练数据多样性;
  • 调整学习率:使用torch.optim.lr_scheduler动态调整学习率(如 StepLR、CosineAnnealingLR);
  • 模型加深:增加卷积层数量或使用预训练模型(如 ResNet,参考 PPT “迁移学习” 模块)。

五、总结

本文基于 PyTorch 实现了一个简易的卷积神经网络,完成了 MNIST 手写数字识别任务,核心流程可总结为:

  1. 数据加载:使用datasets.MNISTDataLoader处理数据,支持批量加载;
  2. 模型构建:遵循 CNN 的核心结构(卷积层 + 池化层 + 全连接层),通过nn.Module自定义网络;
  3. 训练与测试:分别定义训练和测试函数,监控损失和准确率,评估模型性能;
  4. 优化方向:可结合数据增强、动态学习率调整、迁移学习等技术进一步提升模型性能。

通过本文的实践,不仅掌握了卷积神经网络的基本原理(如 PPT 中讲解的卷积操作、池化作用、感受野等),还熟悉了 PyTorch 框架的核心用法,为后续更复杂的图像识别任务(如目标检测、图像分割)奠定基础。

http://www.xdnf.cn/news/1390717.html

相关文章:

  • 彻底弄清URI、URL、URN的关系
  • BGP路由协议(二):报文的类型和格式
  • OpenAI宣布正式推出Realtime API
  • 网络_协议
  • Qt事件_xiaozuo
  • 快速深入理解zookeeper特性及核心基本原理
  • Replay – AI音乐伴奏分离工具,自动分析音频内容、提取主唱、人声和伴奏等音轨
  • rust打包增加图标
  • 常见视频编码格式对比
  • 【3D入门-指标篇下】 3D重建评估指标对比-附实现代码
  • 哈希算法完全解析:从原理到实战
  • Python OpenCV图像处理与深度学习
  • 网页提示UI操作-适应提示,警告,信息——仙盟创梦IDE
  • 【贪心算法】day4
  • 实现自己的AI视频监控系统-第二章-AI分析模块5(重点)
  • 【开题答辩全过程】以 基于SpringBootVue的智能敬老院管理系统为例,包含答辩的问题和答案
  • 为什么特征缩放对数字货币预测至关重要
  • 克隆态驱动给用户态使用流程
  • Python 异步编程:await、asyncio.gather 和 asyncio.create_task 的区别与最佳实践
  • 【DeepSeek】公司内网部署离线deepseek+docker+ragflow本地模型实战
  • 软考-系统架构设计师 办公自动化系统(OAS)详细讲解
  • 【C语言】深入理解指针(2)
  • [打包压缩] gzip压缩和解压缩介绍
  • webservice在进行run maven build中出现java.lang.ClassCastException错误
  • C++基础(⑤删除链表中的重复节点(链表 + 遍历))
  • 【C++闯关笔记】STL:vector的学习与使用
  • Spring Security 传统 web 开发场景下开启 CSRF 防御原理与源码解析
  • CorrectNav:用错误数据反哺训练的视觉语言导航新突破
  • Apache服务器IP 自动跳转域名教程​
  • electron-vite 配合python