当前位置: 首页 > ai >正文

lesson05-手写数据问题案例实战(理论+代码)

在本篇文章中,我们将详细探讨如何使用简单的神经网络模型对手写数字进行识别。我们将从数据准备开始,介绍整个流程直至模型推理。 

一、准备数据集

       首先,我们需要一个合适的数据集来进行训练和测试。这里我们选择的是著名的 MNIST 数据集,它包含了大量的手写数字图像(0-9),每个数字有7000张图像,总共60,000张用于训练,10,000张用于测试。

二、没有深度学习只有映射

在这个阶段,我们将不依赖于复杂的深度学习架构,而是通过简单的线性映射来实现函数逼近。输入是一个28x28像素的灰度图像,展平后形成一个长度为784的一维向量X。

 三、损失函数

为了衡量预测值与真实值之间的差距,我们需要定义一个损失函数。在这里,我们选择了欧几里得距离作为损失函数,计算预测输出H3与实际标签Y之间的差异。

 四、非线性因子

为了让模型具有更强的表现力,我们引入了非线性激活函数ReLU。这有助于捕捉输入数据中的复杂模式,而非简单地执行线性变换。

五、梯度下降

为了最小化损失函数,我们采用梯度下降算法调整权重和偏置项。目标是最小化预测值与真实值之间的差异。

六、推理

最后,在完成模型训练之后,我们可以用该模型对新的输入进行预测。对于给定的新输入X1,通过前向传播得到预测结果,并根据最大概率确定最终的分类结果。

 

总结

本文简要介绍了如何构建一个基本的神经网络用于手写数字识别任务。这个例子相对简单,但它涵盖了机器学习项目的基本步骤,包括数据预处理、模型设计、训练过程以及最终的推理应用。希望这篇文章能够帮助初学者理解并入门这一领域。

 代码案例:

🔍 一、导入库

import torch
from torch import nn
from torch.nn import functional as F
from torch import optimimport torchvision
from matplotlib import pyplot as pltfrom utils import plot_image, plot_curve, one_hot
  • torch: PyTorch 的核心库。
  • nn: 提供神经网络层,如线性层、卷积层等。
  • F: 包含激活函数、损失函数等。
  • optim: 提供优化器,如 SGD、Adam 等。
  • torchvision: 提供常用数据集(如 MNIST)和图像变换工具。
  • matplotlib.pyplot: 用于绘图。
  • utils: 自定义辅助函数:
    • plot_image: 显示图像样本。
    • plot_curve: 绘制训练 loss 曲线。
    • one_hot: 将类别标签转换为 one-hot 编码。

📦 二、设置 batch_size 并加载数据集

batch_size = 512

加载训练集

train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data', train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)
  • 使用 DataLoader 加载 MNIST 数据集。
  • transform 对图像进行预处理:
    • ToTensor(): 将图像转为 [0,1] 范围内的张量。
    • Normalize((0.1307,), (0.3081)): 对灰度图做标准化(均值和标准差来自 MNIST 训练集统计)。
  • shuffle=True: 每个 epoch 开始时打乱数据。 

加载测试集 

test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data/', train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=False)

      测试集不需要打乱顺序。

查看一个 batch 的数据结构

x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, 'image sample')
  • x.shape[512, 1, 28, 28] → 表示 batch_size=512,单通道(灰度图),28x28 像素。
  • y.shape[512] → 每个样本对应的数字标签(0~9)。
  • plot_image():可视化一批次图像和对应标签。

🧠 三、构建神经网络模型

class Net(nn.Module):def __init__(self):super(Net, self).__init__()# xw+bself.fc1 = nn.Linear(28*28, 256)self.fc2 = nn.Linear(256, 64)self.fc3 = nn.Linear(64, 10)def forward(self, x):# x: [b, 1, 28, 28]# h1 = relu(xw1+b1)x = F.relu(self.fc1(x))# h2 = relu(h1w2+b2)x = F.relu(self.fc2(x))# h3 = h2w3+b3x = self.fc3(x)return x
  • 定义了一个三层全连接神经网络:
    • 输入层:28×28 = 784 维。
    • 隐藏层1:256 个神经元。
    • 隐藏层2:64 个神经元。
    • 输出层:10 个神经元(对应 10 个数字类别)。
  • 使用 ReLU 激活函数。
  • 模型结构简单但足够完成 MNIST 分类任务。

⚙️ 四、定义优化器

net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
  • 创建模型实例 net
  • 使用随机梯度下降(SGD)作为优化器,学习率 lr=0.01,动量 momentum=0.9 可以加速收敛。

📈 五、训练模型

train_loss = []for epoch in range(3):for batch_idx, (x, y) in enumerate(train_loader):x = x.view(x.size(0), 28*28)  # 展平输入out = net(x)                  # 前向传播y_onehot = one_hot(y)         # 标签 one-hot 编码loss = F.mse_loss(out, y_onehot)  # 使用均方误差损失optimizer.zero_grad()         # 清空梯度loss.backward()               # 反向传播optimizer.step()              # 参数更新train_loss.append(loss.item())if batch_idx % 10 == 0:print(epoch, batch_idx, loss.item())plot_curve(train_loss)
  • 前向传播:输入图像展平后送入网络,输出预测结果。
  • 标签编码:将整数标签转换为 one-hot 向量,便于计算损失。
  • 损失函数:使用均方误差(MSE)代替交叉熵损失(虽然不太推荐,但对简单任务也能工作)。
  • 反向传播:计算梯度并更新参数。
  • 记录 loss:绘制训练曲线,观察模型是否在学习。

🧪 六、测试模型性能

total_correct = 0
for x, y in test_loader:x = x.view(x.size(0), 28*28)out = net(x)pred = out.argmax(dim=1)  # 取最大概率的类别作为预测correct = pred.eq(y).sum().float().item()total_correct += correcttotal_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test acc:', acc)
  • 在测试集上评估模型准确率。
  • 使用 argmax() 获取预测类别。
  • eq() 判断预测与真实标签是否一致,求和得到正确数。
  • 最终输出测试准确率。

🖼️ 七、可视化测试结果

x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, 'test')
  • 取出一批测试图像,用模型预测,显示预测结果。

✅ 总结

模块功能
数据加载使用 DataLoader + transforms 加载并预处理 MNIST 数据
模型定义构建三层全连接网络,使用 ReLU 激活函数
损失函数使用 MSE Loss(建议后期改为 CrossEntropyLoss)
优化器使用带动量的 SGD
训练流程前向传播、计算损失、反向传播、更新参数
测试流程评估模型准确率,并可视化预测结果

完整代码demo: 

import  torch
from    torch import nn
from    torch.nn import functional as F
from    torch import optimimport  torchvision
from    matplotlib import pyplot as pltfrom    utils import plot_image, plot_curve, one_hotbatch_size = 512# step1. load dataset
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data', train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data/', train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=False)x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, 'image sample')class Net(nn.Module):def __init__(self):super(Net, self).__init__()# xw+bself.fc1 = nn.Linear(28*28, 256)self.fc2 = nn.Linear(256, 64)self.fc3 = nn.Linear(64, 10)def forward(self, x):# x: [b, 1, 28, 28]# h1 = relu(xw1+b1)x = F.relu(self.fc1(x))# h2 = relu(h1w2+b2)x = F.relu(self.fc2(x))# h3 = h2w3+b3x = self.fc3(x)return xnet = Net()
# [w1, b1, w2, b2, w3, b3]
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)train_loss = []for epoch in range(3):for batch_idx, (x, y) in enumerate(train_loader):# x: [b, 1, 28, 28], y: [512]# [b, 1, 28, 28] => [b, 784]x = x.view(x.size(0), 28*28)# => [b, 10]out = net(x)# [b, 10]y_onehot = one_hot(y)# loss = mse(out, y_onehot)loss = F.mse_loss(out, y_onehot)optimizer.zero_grad()loss.backward()# w' = w - lr*gradoptimizer.step()train_loss.append(loss.item())if batch_idx % 10==0:print(epoch, batch_idx, loss.item())plot_curve(train_loss)
# we get optimal [w1, b1, w2, b2, w3, b3]total_correct = 0
for x,y in test_loader:x  = x.view(x.size(0), 28*28)out = net(x)# out: [b, 10] => pred: [b]pred = out.argmax(dim=1)correct = pred.eq(y).sum().float().item()total_correct += correcttotal_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test acc:', acc)x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, 'test')

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

http://www.xdnf.cn/news/13546.html

相关文章:

  • linux回收站
  • 爱普生TG5032SGN同步以太网的高精度时钟解决方案
  • P2840 纸币问题 2
  • 华为OD机考-数字螺旋矩阵(JAVA 2025B卷)
  • Python前端系列(三)
  • DATABASE 结构迁移实战手册:脚本生成、分类与部署全流程详解
  • 华为云Flexus+DeepSeek征文|华为云CCE容器高可用部署Dify LLM应用后的资源释放指南
  • 掌握Linux进程替换:从原理到实战(自定义shell)
  • 笔试模拟day1
  • 随记 使用certbot申请ssl证书
  • 跨域的本质与实战:从理论到松鼠短视频系统的演进-优雅草卓伊凡|卢健bigniu
  • 数据库游标:逐行处理数据的“手术刀”——从原理到实战的深度解析
  • 开关电源-KA3842A芯片的电路分析
  • CSS“多列布局”
  • 电池充放电容量检测:能否精准锁定电池真实性能?
  • PSCAD closed loop buck converter
  • 打卡day51
  • CMake安装教程
  • 2025GEO供应商排名深度解析:源易信息构建AI生态优势
  • 新德通:光通信领域的硬核力量,引领高速互联新时代
  • Appium + Node.js 测试全流程
  • 最接近的三数之和
  • Java 基础知识填空题(共 10 题)
  • 6.ref创建对象类型的响应式数据
  • FPGA实现VESA DSC编码功能
  • 【游戏项目】大型项目Git分支策略与开发流程设计构想
  • 无人机智能运行系统技术解析
  • 为进行性核上性麻痹患者定制:饮食健康指南
  • 全球首个体重管理AI大模型“减单”发布,学AI大模型来近屿智能
  • CMake指令: add_sub_directory以及工作流程