当前位置: 首页 > news >正文

用PyTorch搭建卷积神经网络实现MNIST手写数字识别

用PyTorch搭建卷积神经网络实现MNIST手写数字识别

在深度学习领域,卷积神经网络(Convolutional Neural Network,简称CNN)是处理图像数据的强大工具。它通过卷积层、池化层和全连接层等组件,自动提取图像特征,在图像分类、目标检测等任务中表现卓越。本文将使用PyTorch框架,搭建一个CNN模型来实现MNIST手写数字识别,并详细解析每一步代码。

一、MNIST数据集介绍

MNIST数据集是深度学习领域经典的入门数据集,包含70,000张手写数字图像,其中60,000张用于训练,10,000张用于测试。这些图像均为灰度图,尺寸是28x28像素,并且已经做了居中处理,这在一定程度上减少了预处理的工作量,能够加快模型的训练和运行速度。

二、环境准备与数据加载

2.1 导入必要的库

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

上述代码导入了PyTorch的核心库、神经网络模块、数据加载工具以及用于图像数据处理和数据集管理的库。

2.2 下载并加载数据集

training_data = datasets.MNIST(root='data',train=True,download=True,transform=ToTensor()
)test_data = datasets.MNIST(root='data',train=False,download=True,transform=ToTensor()
)

通过datasets.MNIST函数分别下载训练集和测试集。root参数指定数据下载的路径;train=True表示下载训练集数据,train=False则表示下载测试集数据;download=True确保如果数据尚未下载,会自动进行下载;transform=ToTensor()将图像数据转换为PyTorch能够处理的张量格式。

2.3 数据可视化

from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):img, label = training_data[i + 59000]figure.add_subplot(3, 3, i + 1)plt.title(label)plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")
plt.show()

这段代码使用matplotlib库展示了训练数据集中的部分手写数字图像,通过plt.imshow函数将张量格式的图像数据可视化,直观感受MNIST数据集的内容。

2.4 创建数据加载器

train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

DataLoader用于将数据集打包成批次,batch_size参数指定每个批次包含的数据样本数量。将数据集分成批次进行训练,能够有效减少内存使用,并提高训练速度。

三、设备配置

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

这段代码检测当前设备是否支持GPU(CUDA)或苹果M系列芯片的GPU(MPS),如果都不支持,则使用CPU进行计算。后续模型和数据都会被移动到选定的设备上运行,以充分利用硬件资源加速训练。

四、定义卷积神经网络模型

class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1,out_channels=16,kernel_size=5,stride=1,padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2))self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),nn.ReLU(),nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2))self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),nn.ReLU())self.out = nn.Linear(64 * 7 * 7, 10)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)output = self.out(x)return output

在这个自定义的CNN类中,继承自nn.Module__init__方法中定义了网络的结构:

  • 卷积层(nn.Conv2d:用于提取图像特征,通过设置in_channels(输入通道数)、out_channels(输出通道数,即卷积核个数)、kernel_size(卷积核大小)、stride(步长)和padding(填充)等参数,控制卷积操作。
  • 激活函数层(nn.ReLU:引入非线性,增强网络的表达能力。
  • 池化层(nn.MaxPool2d:对特征图进行下采样,减少数据量和计算量,同时保留主要特征。
  • 全连接层(nn.Linear:将卷积层和池化层提取的特征映射到输出类别(MNIST数据集中有10个数字类别)。

forward方法定义了数据在网络中的前向传播路径,确保数据按照网络结构依次经过各层处理,最终输出预测结果。

五、训练与测试模型

5.1 定义损失函数和优化器

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

nn.CrossEntropyLoss是适用于多分类任务的交叉熵损失函数,用于计算模型预测结果与真实标签之间的差距。torch.optim.Adam是一种常用的优化器,通过调整模型的参数(model.parameters())来最小化损失函数,lr参数设置学习率,控制参数更新的步长。

5.2 训练函数

def train(dataloader, model, loss_fn, optimizer):model.train()batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value = loss.item()if batch_size_num % 100 == 0:print(f'loss:{loss_value:>7f} [number:{batch_size_num}]')batch_size_num += 1

在训练函数中:

  • model.train()将模型设置为训练模式,此时模型中的一些层(如Dropout层)会按照训练规则工作。
  • 遍历数据加载器中的每一个批次数据,将数据和标签移动到指定设备上。
  • 通过模型进行预测,计算损失值。
  • 使用optimizer.zero_grad()清零梯度,loss.backward()进行反向传播计算梯度,optimizer.step()根据梯度更新模型参数。
  • 每隔100个批次,打印当前的损失值,以便观察训练过程中的损失变化。

5.3 测试函数

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f'Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}')

测试函数中:

  • model.eval()将模型设置为测试模式,关闭一些在训练过程中起作用但在测试时不需要的操作(如Dropout)。
  • 使用with torch.no_grad()上下文管理器,关闭梯度计算,因为在测试阶段不需要更新模型参数,这样可以节省计算资源。
  • 遍历测试数据,计算每个批次的损失值并累加,同时统计预测正确的样本数量。
  • 最后计算并打印测试集上的平均损失和准确率,评估模型的性能。

5.4 执行训练和测试

epoch = 9
for i in range(epoch):print(i + 1)train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)

通过设置训练轮数(epoch),循环调用训练函数进行模型训练,每一轮训练结束后,调用测试函数评估模型在测试集上的性能。

六、总结

本文通过详细的代码解析,展示了如何使用PyTorch搭建一个简单的卷积神经网络来实现MNIST手写数字识别任务。从数据加载、模型定义,到训练和测试,每一个步骤都体现了CNN在图像分类任务中的核心思想和实现方法。通过不断调整模型结构、超参数等,还可以进一步提升模型的性能。卷积神经网络在图像领域的应用远不止于此,它在更复杂的图像任务和其他领域也有着广泛的应用前景,希望本文能为大家深入学习深度学习提供一个良好的开端。

http://www.xdnf.cn/news/263809.html

相关文章:

  • 生成式 AI 的工作原理
  • Elasticsearch 中的索引模板:如何使用可组合模板
  • 【在Spring Boot中集成Redis】
  • 【赵渝强老师】TiDB生态圈组件
  • 3D人物关系图开发实战:Three.js实现自动旋转可视化图谱(附完整代码)
  • 人工智能助力工业制造:迈向智能制造的未来
  • 别样健康养生之道
  • AI 与生物技术的融合:开启精准医疗的新纪元
  • ros2 humble 控制真实机械臂(以lerobot为例)
  • 一种基于重建前检测的实孔径雷达实时角超分辨方法——论文阅读
  • **Java面试大冒险:谢飞机的幽默与技术碰撞记**
  • 做响应式布局网页多简单
  • AI生成视频检测方法及其相关研究
  • WebRTC 服务器之Janus概述和环境搭建
  • Spring MVC入门
  • 第12章:精神力的禁忌边界
  • 强化学习--3.值函数的方法(贝尔曼方程)
  • 直播推流拉流Token验证流程(直播服务器:SRS,验证服务器:EGGS(nodejs))
  • 智能决策支持系统的系统结构:四库架构与融合范式
  • k8s笔记——kubebuilder工作流程
  • 嵌入式硬件篇---STM32F103C8T6STM32F103RCT6
  • Flink 的状态机制
  • Qt中实现工厂模式
  • 音视频开源项目列表
  • 【2025年】MySQL面试题总结
  • 实战探讨:为什么 Redis Zset 选择跳表?
  • xLua笔记
  • 55.[前端开发-前端工程化]Day02-包管理工具npm等
  • Oracle 11g通过dg4odbc配置dblink连接神通数据库
  • Oracle RAC ‘Metrics Global Cache Blocks Lost‘告警解决处理