当前位置: 首页 > backend >正文

深度学习——卷积神经网络(PyTorch 实现 MNIST 手写数字识别案例)

传统机器学习实现同一项目:(只有定义神经网络的部分不同,解析更加详细)

深度学习——神经网络(PyTorch 实现 MNIST 手写数字识别案例)-CSDN博客https://blog.csdn.net/2302_78022640/article/details/150781035


案例教学:使用 PyTorch 构建卷积神经网络实现 MNIST 手写数字识别

本文将通过一个完整的 PyTorch 实现案例,带你逐步理解如何下载数据集、构建神经网络模型、训练和测试模型,并最终实现手写数字识别。案例选用的是 MNIST 数据集,这是一个入门级的图像分类任务。


一、准备工作

首先,导入所需的依赖库:

# import torch
# print(torch.__version__)'''MNIST包含70000张手写数字图像:60000用于训练,10000用于测试图像是灰度的,28×28像素的,并且居中的,以减少预处理和加快运行
'''
import torch
from torch import nn    #导入神经网络模块
from torch.utils.data import DataLoader  #数据包管理工具,打包数据
from torchvision import  datasets  #封装了很多与图像相关的模型,数据集
from torchvision.transforms import ToTensor  #数据转换,张量,将其他类型的数据转换为tensor张量,numpy array
  • torch:PyTorch 的核心包,提供张量运算和深度学习构建的基础。

  • nn:神经网络模块,用于搭建层结构(卷积层、全连接层等)。

  • DataLoader:数据加载器,可以自动打包数据,支持批量读取。

  • datasets:提供常用的数据集(如 MNIST、CIFAR10)。

  • ToTensor:将图片转换为张量格式,方便神经网络使用。


二、加载数据集

'''下载训练数据集(包含训练图片+标签)'''
training_data = datasets.MNIST(root="data", train=True, download=True, transform=ToTensor(), 
)   '''下载测试数据集(包含测试图片+标签)'''
test_data = datasets.MNIST(root="data", train=False, download=True, transform=ToTensor(), 
)   
print(len(training_data))

这里使用 MNIST 数据集

  • train=True:加载训练集(60000 张图片)。

  • train=False:加载测试集(10000 张图片)。

  • transform=ToTensor():将数据转为张量,方便送入模型。

输出结果是 60000,表示训练集中有 60000 张图像。


三、数据可视化(可选)

代码中提供了展示部分图像的功能:

# '''展示手写数字图片,把训练集中的59000张图片展示'''
# from matplotlib import pyplot as plt
# figure = plt.figure()
# for i in range(9):
#     img,label = training_data[i+59000] 
#     figure.add_subplot(3,3,i+1) 
#     plt.title(label)
#     plt.axis("off")  
#     plt.imshow(img.squeeze(),cmap="gray") 
#     a = img.squeeze()  
# plt.show()

这段代码会在一个 3x3 的网格中展示 9 张手写数字。img.squeeze() 是为了去掉多余的维度,使图像能被 imshow 正确显示。


四、创建数据加载器

'''创建数据DataLoader(数据加载器)'''
train_dataloader = DataLoader(training_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)
for X,y in test_dataloader:print(f"Shape of X[N,C,H,W]:{X.shape}")print(f"Shape of y: f{y.shape} {y.dtype}")break
  • batch_size=64:每次从数据集中读取 64 张图片作为一个批次。

  • 优点:节省内存、加快训练速度。

输出结果类似:

Shape of X[N,C,H,W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64

解释:

  • 64:批次大小

  • 1:通道数(灰度图)

  • 28x28:图像大小


五、选择运行设备(CPU/GPU)

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
  • cuda:使用 NVIDIA GPU。

  • mps:Apple M 系列芯片的 GPU。

  • cpu:若无 GPU,则使用 CPU。


六、定义卷积神经网络(CNN)

''' 定义神经网络  类的继承这种方式'''
class CNN(nn.Module): def __init__(self):   super(CNN,self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(1,16,3,1,1),nn.ReLU(),nn.MaxPool2d(kernel_size=2),)self.conv2 = nn.Sequential(nn.Conv2d(16, 16, 3, 1, 1),nn.ReLU(),nn.Conv2d(16, 32, 3, 1, 1),nn.ReLU(),nn.Conv2d(32, 32, 3, 1, 1),nn.ReLU(),nn.MaxPool2d(kernel_size=2), )self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 3, 1, 1),nn.ReLU(),nn.Conv2d(64, 64, 3, 1, 1),nn.ReLU(),)self.out = nn.Linear(64*7*7,10)def forward(self,x):  x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)  x = x.view(x.size(0), -1) output = self.out(x)return outputmodel = CNN().to(device)
print(model)

模型结构说明:输入1*28*28(64 张图片作为一个批次。故 64*1*28*28)

  1. conv1:卷积 + ReLU + 池化 → 输出 16*14*14

  2. conv2:多层卷积 + ReLU + 池化 → 输出 32*7*7

  3. conv3:卷积层 → 输出 64*7*7

  4. Linear 全连接层:输入 64*7*7,输出 10(对应 0~9 的数字分类)。

也可以定义其他卷积神经网络:(多条一条各种参数,加深了解)

class CNN(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Sequential(#容器,添加网络层nn.Conv2d(in_channels=1,out_channels = 16,kernel_size = 5,stride = 1,padding = 2,),nn.ReLU(),nn.MaxPool2d(kernel_size = 2),)self.conv2 = nn.Sequential(  # 容器,添加网络层nn.Conv2d(16, 32, 5, 1, 2),nn.ReLU(),nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.Conv2d(32, 64, 5, 1, 2),nn.ReLU(),nn.Conv2d(64, 64, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(kernel_size=2),)self.conv3 = nn.Sequential(nn.Conv2d(64, 64, 5, 1, 2),nn.ReLU(),)self.out = nn.Linear(64*7*7,10)def forward(self,x):  x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)  x = x.view(x.size(0), -1) output = self.out(x)return output


七、训练函数

def train(dataloader,model,loss_fn,optimizer):model.train() batch_size_num = 1for X,y in dataloader:              X,y = X.to(device),y.to(device) pred = model.forward(X)         loss = loss_fn(pred,y)          optimizer.zero_grad()           loss.backward()                 optimizer.step()                loss_value = loss.item()        if batch_size_num %100 ==0:print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1

核心步骤:

  1. 前向传播:计算预测结果 pred

  2. 计算损失:loss_fn(pred,y)

  3. 反向传播:loss.backward() 计算梯度。

  4. 参数更新:optimizer.step()


八、测试函数

def Test(dataloader,model,loss_fn):size = len(dataloader.dataset)  num_batches = len(dataloader)  model.eval()        test_loss,correct =0,0with torch.no_grad():       for X,y in dataloader:X,y = X.to(device),y.to(device)pred = model.forward(X)test_loss += loss_fn(pred,y).item() correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batches correct /= size  print(f"Test result: \n Accuracy:{(100*correct)}%, Avg loss:{test_loss}")

这里使用 eval() 模式no_grad()

  • eval():固定参数,关闭 dropout、BN 等训练机制。

  • no_grad():关闭梯度计算,节省内存和计算量。


九、定义损失函数和优化器

loss_fn = nn.CrossEntropyLoss()  
optimizer = torch.optim.Adam(model.parameters(),lr=0.005) 
  • CrossEntropyLoss:常用于分类任务。

  • Adam 优化器:比 SGD 收敛更快。


十、模型训练与测试

epochs = 10
for t in range(epochs):print(f"epoch {t+1}\n---------------")train(train_dataloader,model,loss_fn,optimizer)
print("Done!")
Test(test_dataloader,model,loss_fn)
  • 训练 10 轮 (epochs=10)

  • 每轮训练后,会不断更新模型参数。

  • 最后调用 Test() 测试在测试集上的准确率和损失。

一般情况下,训练 10 轮后,模型在 MNIST 上的准确率可以达到 98%~99% 左右


总结

通过本案例,我们完成了以下步骤:

  1. 下载并加载 MNIST 数据集。

  2. 构建卷积神经网络(CNN)模型。

  3. 使用训练集训练模型,并记录损失变化。

  4. 使用测试集评估模型性能。

这是一个非常经典的 深度学习入门案例,帮助初学者理解卷积神经网络在图像分类中的应用。

http://www.xdnf.cn/news/18948.html

相关文章:

  • pcl_案例2 叶片与根茎的分离
  • 机器视觉学习-day09-图像矫正
  • Day30 多线程编程 同步与互斥 任务队列调度
  • leetcode_73 矩阵置零
  • 【LLM】Transformer模型中的MoE层详解
  • vue布局
  • 架构设计——云原生与分布式系统架构
  • Android中设置RecyclerView滑动到指定条目位置
  • 搜维尔科技核心产品矩阵涵盖从硬件感知到软件渲染的全产品供应链
  • 万博智云联合华为云共建高度自动化的云容灾基线解决方案
  • 【Python开源环境】Anaconda/Miniconda
  • 【数据结构与算法】(LeetCode)141.环形链表 142.环形链表Ⅱ
  • 重置 Windows Server 2019 管理员账户密码
  • 深入理解QLabel:Qt中的文本与图像显示控件
  • 国产的服务器
  • 机器学习回顾(一)
  • Day16_【机器学习—KNN算法】
  • 小白入门:支持深度学习的视觉数据库管理系统
  • 解构与重构:“真人不露相,露相非真人” 的存在论新解 —— 论 “真在” 的行为表达本质
  • c++ 观察者模式 订阅发布架构
  • Visual Scope (Serial_Digital_Scope V2) “串口 + 虚拟示波器” 工具使用记录
  • JavaScript中的BOM,DOM和事件
  • Spring Boot 实战:接入 DeepSeek API 实现问卷文本优化
  • 底层音频编程的基本术语 PCM 和 Mixer
  • 数据分析学习笔记4:加州房价预测
  • 腕上智慧健康管家:华为WATCH 5与小艺的智美生活新范式
  • 音频转PCM
  • curl、python-requests、postman和jmeter的对应关系
  • AR培训系统:油气行业的安全与效率革新
  • frp 一个高性能的反向代理服务