lesson05-手写数据问题案例实战(理论+代码)
在本篇文章中,我们将详细探讨如何使用简单的神经网络模型对手写数字进行识别。我们将从数据准备开始,介绍整个流程直至模型推理。
一、准备数据集
首先,我们需要一个合适的数据集来进行训练和测试。这里我们选择的是著名的 MNIST 数据集,它包含了大量的手写数字图像(0-9),每个数字有7000张图像,总共60,000张用于训练,10,000张用于测试。
二、没有深度学习只有映射
在这个阶段,我们将不依赖于复杂的深度学习架构,而是通过简单的线性映射来实现函数逼近。输入是一个28x28像素的灰度图像,展平后形成一个长度为784的一维向量X。
三、损失函数
为了衡量预测值与真实值之间的差距,我们需要定义一个损失函数。在这里,我们选择了欧几里得距离作为损失函数,计算预测输出H3与实际标签Y之间的差异。
四、非线性因子
为了让模型具有更强的表现力,我们引入了非线性激活函数ReLU。这有助于捕捉输入数据中的复杂模式,而非简单地执行线性变换。
五、梯度下降
为了最小化损失函数,我们采用梯度下降算法调整权重和偏置项。目标是最小化预测值与真实值之间的差异。
六、推理
最后,在完成模型训练之后,我们可以用该模型对新的输入进行预测。对于给定的新输入X1,通过前向传播得到预测结果,并根据最大概率确定最终的分类结果。
总结
本文简要介绍了如何构建一个基本的神经网络用于手写数字识别任务。这个例子相对简单,但它涵盖了机器学习项目的基本步骤,包括数据预处理、模型设计、训练过程以及最终的推理应用。希望这篇文章能够帮助初学者理解并入门这一领域。
代码案例:
🔍 一、导入库
import torch
from torch import nn
from torch.nn import functional as F
from torch import optimimport torchvision
from matplotlib import pyplot as pltfrom utils import plot_image, plot_curve, one_hot
torch
: PyTorch 的核心库。nn
: 提供神经网络层,如线性层、卷积层等。F
: 包含激活函数、损失函数等。optim
: 提供优化器,如 SGD、Adam 等。torchvision
: 提供常用数据集(如 MNIST)和图像变换工具。matplotlib.pyplot
: 用于绘图。utils
: 自定义辅助函数:plot_image
: 显示图像样本。plot_curve
: 绘制训练 loss 曲线。one_hot
: 将类别标签转换为 one-hot 编码。
📦 二、设置 batch_size 并加载数据集
batch_size = 512
加载训练集
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data', train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)
- 使用
DataLoader
加载 MNIST 数据集。 transform
对图像进行预处理:ToTensor()
: 将图像转为 [0,1] 范围内的张量。Normalize((0.1307,), (0.3081))
: 对灰度图做标准化(均值和标准差来自 MNIST 训练集统计)。
shuffle=True
: 每个 epoch 开始时打乱数据。
加载测试集
test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data/', train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=False)
测试集不需要打乱顺序。
查看一个 batch 的数据结构
x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, 'image sample')
x.shape
:[512, 1, 28, 28]
→ 表示 batch_size=512,单通道(灰度图),28x28 像素。y.shape
:[512]
→ 每个样本对应的数字标签(0~9)。plot_image()
:可视化一批次图像和对应标签。
🧠 三、构建神经网络模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()# xw+bself.fc1 = nn.Linear(28*28, 256)self.fc2 = nn.Linear(256, 64)self.fc3 = nn.Linear(64, 10)def forward(self, x):# x: [b, 1, 28, 28]# h1 = relu(xw1+b1)x = F.relu(self.fc1(x))# h2 = relu(h1w2+b2)x = F.relu(self.fc2(x))# h3 = h2w3+b3x = self.fc3(x)return x
- 定义了一个三层全连接神经网络:
- 输入层:28×28 = 784 维。
- 隐藏层1:256 个神经元。
- 隐藏层2:64 个神经元。
- 输出层:10 个神经元(对应 10 个数字类别)。
- 使用 ReLU 激活函数。
- 模型结构简单但足够完成 MNIST 分类任务。
⚙️ 四、定义优化器
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
- 创建模型实例
net
。 - 使用随机梯度下降(SGD)作为优化器,学习率
lr=0.01
,动量momentum=0.9
可以加速收敛。
📈 五、训练模型
train_loss = []for epoch in range(3):for batch_idx, (x, y) in enumerate(train_loader):x = x.view(x.size(0), 28*28) # 展平输入out = net(x) # 前向传播y_onehot = one_hot(y) # 标签 one-hot 编码loss = F.mse_loss(out, y_onehot) # 使用均方误差损失optimizer.zero_grad() # 清空梯度loss.backward() # 反向传播optimizer.step() # 参数更新train_loss.append(loss.item())if batch_idx % 10 == 0:print(epoch, batch_idx, loss.item())plot_curve(train_loss)
- 前向传播:输入图像展平后送入网络,输出预测结果。
- 标签编码:将整数标签转换为 one-hot 向量,便于计算损失。
- 损失函数:使用均方误差(MSE)代替交叉熵损失(虽然不太推荐,但对简单任务也能工作)。
- 反向传播:计算梯度并更新参数。
- 记录 loss:绘制训练曲线,观察模型是否在学习。
🧪 六、测试模型性能
total_correct = 0
for x, y in test_loader:x = x.view(x.size(0), 28*28)out = net(x)pred = out.argmax(dim=1) # 取最大概率的类别作为预测correct = pred.eq(y).sum().float().item()total_correct += correcttotal_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test acc:', acc)
- 在测试集上评估模型准确率。
- 使用
argmax()
获取预测类别。 eq()
判断预测与真实标签是否一致,求和得到正确数。- 最终输出测试准确率。
🖼️ 七、可视化测试结果
x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, 'test')
- 取出一批测试图像,用模型预测,显示预测结果。
✅ 总结
模块 | 功能 |
---|---|
数据加载 | 使用 DataLoader + transforms 加载并预处理 MNIST 数据 |
模型定义 | 构建三层全连接网络,使用 ReLU 激活函数 |
损失函数 | 使用 MSE Loss(建议后期改为 CrossEntropyLoss) |
优化器 | 使用带动量的 SGD |
训练流程 | 前向传播、计算损失、反向传播、更新参数 |
测试流程 | 评估模型准确率,并可视化预测结果 |
完整代码demo:
import torch
from torch import nn
from torch.nn import functional as F
from torch import optimimport torchvision
from matplotlib import pyplot as pltfrom utils import plot_image, plot_curve, one_hotbatch_size = 512# step1. load dataset
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data', train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data/', train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=False)x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, 'image sample')class Net(nn.Module):def __init__(self):super(Net, self).__init__()# xw+bself.fc1 = nn.Linear(28*28, 256)self.fc2 = nn.Linear(256, 64)self.fc3 = nn.Linear(64, 10)def forward(self, x):# x: [b, 1, 28, 28]# h1 = relu(xw1+b1)x = F.relu(self.fc1(x))# h2 = relu(h1w2+b2)x = F.relu(self.fc2(x))# h3 = h2w3+b3x = self.fc3(x)return xnet = Net()
# [w1, b1, w2, b2, w3, b3]
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)train_loss = []for epoch in range(3):for batch_idx, (x, y) in enumerate(train_loader):# x: [b, 1, 28, 28], y: [512]# [b, 1, 28, 28] => [b, 784]x = x.view(x.size(0), 28*28)# => [b, 10]out = net(x)# [b, 10]y_onehot = one_hot(y)# loss = mse(out, y_onehot)loss = F.mse_loss(out, y_onehot)optimizer.zero_grad()loss.backward()# w' = w - lr*gradoptimizer.step()train_loss.append(loss.item())if batch_idx % 10==0:print(epoch, batch_idx, loss.item())plot_curve(train_loss)
# we get optimal [w1, b1, w2, b2, w3, b3]total_correct = 0
for x,y in test_loader:x = x.view(x.size(0), 28*28)out = net(x)# out: [b, 10] => pred: [b]pred = out.argmax(dim=1)correct = pred.eq(y).sum().float().item()total_correct += correcttotal_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test acc:', acc)x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, 'test')