当前位置: 首页 > ds >正文

如何理解神经网络训练的循环过程

正向传播 → 反向传播 → 参数更新
这个过程是一个完整的训练迭代,然后不断重复这一轮又一轮,直到模型收敛(损失函数趋于最小),这就是神经网络训练的核心流程。


🧠 详细解释整个循环过程:

我们把整个训练流程拆解成一个“学习周期”,它包括以下三个核心步骤:


1️⃣ 正向传播(Forward Propagation)

  • 输入数据:x
  • 计算每一层的输出,最终得到模型预测值 y(/hat)
  • 根据真实值 y 计算当前损失(Loss):比如均方误差、交叉熵等

目的

  • 得到当前参数下的预测结果
  • 为反向传播提供计算梯度所需的所有中间变量

2️⃣ 反向传播(Backward Propagation)

  • 利用链式法则(Chain Rule)从输出层开始,逐层向前计算每个参数(权重 W 、偏置 b)对损失的影响(即梯度)
  • 所有梯度保存下来供后续使用

目的

  • 知道每个参数怎么影响损失函数
  • 为参数更新提供依据

3️⃣ 参数更新(Parameter Update)

  • 使用梯度下降或其他优化算法(如 Adam、SGD with momentum)来更新参数:
    W : = W − α ⋅ ∂ L ∂ W W := W - \alpha \cdot \frac{\partial L}{\partial W} W:=WαWL
    b : = b − α ⋅ ∂ L ∂ b b := b - \alpha \cdot \frac{\partial L}{\partial b} b:=bαbL
    其中 α 是学习率

目的

  • 调整参数,使下一次预测更准确
  • 让损失函数逐步减小

🔁 循环进行:Epoch × Batch

整个流程通常是在两个嵌套循环中进行的:

for epoch in range(总训练轮数):for batch in 数据集:正向传播 → 计算预测和损失反向传播 → 计算梯度参数更新 → 梯度下降优化参数
  • epoch:完整遍历一遍所有训练数据
  • batch:每次使用的数据子集(mini-batch)

📈 整个训练过程中,我们期望看到的是:

阶段表现
刚开始训练损失大,预测不准
中期训练损失逐渐下降,预测变好
接近收敛损失稳定在一个较小值,模型表现良好

🎯 总结一句话:

神经网络的训练就是一个不断“预测 → 算误差 → 算梯度 → 调参数”的循环过程。通过一次次的正向传播、反向传播和参数更新,模型逐步学会如何做出更准确的预测,直到损失函数达到一个我们认为满意的最小值。


实践

使用 小批量梯度下降(Mini-batch Gradient Descent) 作为优化算法,以 MNIST 手写数字分类任务为例,构建一个简单的神经网络进行训练。

✅ 使用工具:PyTorch

  • 自动求导机制支持反向传播;
  • DataLoader 支持 mini-batch;
  • SGD 优化器实现小批量梯度下降;

🧠 代码实现

import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# ================== 1. 数据准备 ==================
# 数据预处理:标准化 + 张量转换
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 加载训练数据和测试数据
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)# ================== 2. 模型定义 ==================
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(28 * 28, 128)self.relu = nn.ReLU()self.fc2 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 28 * 28)  # 展平图像x = self.fc1(x)x = self.relu(x)x = self.fc2(x)return xmodel = SimpleNet()# ================== 3. 损失函数 ==================
criterion = nn.CrossEntropyLoss()# ================== 4. 小批量梯度下降优化器 ==================
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 使用 SGD,mini-batch 已在 DataLoader 中设置# ================== 5. 训练循环 ==================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)num_epochs = 5for epoch in range(num_epochs):model.train()running_loss = 0.0# DataLoader可以实现数据分批次,train_loader是已经分批后的数据,一个images就是一个batchfor images, labels in train_loader:images, labels = images.to(device), labels.to(device)# 正向传播:输入 -> 输出outputs = model(images)loss = criterion(outputs, labels)# 反向传播:计算梯度optimizer.zero_grad()loss.backward()# 参数更新:使用小批量梯度下降(SGD)# 每次 optimizer.step() 就是一次参数更新 optimizer.step()# 记录了一个 epoch 中所有 batch 的 loss 总和,最后除以 batch 数量,得到平均 loss 并打印出来。running_loss += loss.item()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')

部分参数的辅助理解可以看我另一篇文章:关于epoch、batch_size等参数含义,及optimizer.step()的含义及数学过程

📌 说明

部分内容
数据加载使用 DataLoader 构造 mini-batch(batch_size=64)
正向传播outputs = model(images)
损失计算loss = criterion(outputs, labels)
反向传播loss.backward()
参数更新optimizer.step(),使用的是 SGD 算法

✅ 小批量梯度下降的关键点

  • 每次更新参数只使用一个 batch 的样本(如 64 张图片)
  • 比随机梯度下降更稳定,比批量梯度下降更快
  • PyTorch 的 DataLoader + SGD 优化器天然支持 mini-batch

http://www.xdnf.cn/news/3421.html

相关文章:

  • 产品月报|睿本云4月产品功能迭代
  • MS31860T——8 通道串行接口低边驱动器
  • 制造业行业ERP软件部署全流程指南:从选型到维护,怎么做?
  • 多线程爬虫中实现线程安全的MySQL连接池
  • Java程序员如何设计一个高并发系统?
  • 基于MCP协议实现一个智能审核流程
  • 虚拟内存笔记(一)
  • AVPro Video加载视频文件并播放,可指定视频文件的位置、路径等参数
  • 运用ESS(弹性伸缩)技术实现服务能力的纵向扩展
  • foxmail时不时发送不了邮件问题定位解决过程
  • 苍穹外卖11
  • Windows查看和修改IP,IP互相ping通
  • 使用模块中的`XPath`语法提取非结构化数据
  • Learning vtkjs之ImageMarchingCubes
  • 100 个 NumPy 练习
  • centos安装nginx
  • 新手小白如何查找科研论文?
  • 2025深圳杯东三省数学建模竞赛选题建议+初步分析
  • 26个脑影像工具包合集分享:从预处理到SCI成图
  • 为什么定位关闭了还显示IP属地?
  • 软考中级-软件设计师 数据库(手写笔记)
  • TS类型体操练习
  • Rancher 2.6.3企业级容器管理平台部署实践
  • ESP32-C3 Secure Boot 使用多个签名 Key
  • FEKO许可管理
  • YOLO11改进-模块-引入跨模态注意力机制CMA 提高多尺度 遮挡
  • 6轴、智能、低功耗惯性测量单元BMI270及其OIS接口
  • 开源 RAG 框架对比:LangChain、Haystack、DSPy 技术选型指南
  • 常用矩阵求导
  • Java父类、子类实例初始化顺序详解