当前位置: 首页 > news >正文

关于epoch、batch_size等参数含义,及optimizer.step()的含义及数学过程

下面我们以 小批量梯度下降(Mini-batch Gradient Descent) 为例,来总结并解释:

  • epoch 的含义与作用
  • batch_size 的含义与影响
  • optimizer.step() 是什么及其背后的数学过程
  • 还可以加入 动量(Momentum)优化器 来扩展理解

🧠 一、基本概念总结(总览)

名称含义示例
Epoch整个训练集被完整遍历一次如果有 60,000 张图片,每次 batch 处理 64 张,则一个 epoch ≈ 938 次迭代
Batch Size每次输入给模型的数据样本数量如 batch_size = 64,表示每次用 64 张图像更新模型参数
Optimizer.step()执行一次参数更新,使用当前 batch 的梯度公式:$ W := W - \alpha \cdot g $,其中 $ g $ 是当前 batch 的梯度

📌 二、详细解释

1️⃣ Epoch

✅ 含义:

一个 epoch 表示模型已经“看完了”整个训练数据集一次。

🔁 训练过程:
  • 在每个 epoch 中,模型会对训练集中的所有样本进行前向传播和反向传播;
  • 每处理完一个 batch 就执行一次 optimizer.step()
  • 多个 batch 组成一个 epoch。
📊 示例:
  • 数据总数:60,000
  • batch_size = 64
  • 则一个 epoch ≈ 60000 / 64 ≈ 938 个 iteration

2️⃣ Batch Size

✅ 含义:

一次训练使用的样本数量。是 mini-batch GD 的核心特征。

⚖️ 影响:
特性小 batch size大 batch size
内存占用
训练速度慢但更频繁更新快但更新少
收敛稳定性更好(噪声帮助跳出局部极小)可能陷入局部最优
泛化能力通常更好略差
💡 推荐值:

常见的 batch size 值为 32、64、128、256。


3️⃣ optimizer.step()

✅ 含义:

这是 PyTorch 中的一个方法,用于根据当前计算出的梯度更新模型参数

🧮 数学表达式:

对于每个参数(如权重 $ W $ 和偏置 $ b $),执行:

W : = W − α ⋅ ∂ L ∂ W W := W - \alpha \cdot \frac{\partial L}{\partial W} W:=WαWL
b : = b − α ⋅ ∂ L ∂ b b := b - \alpha \cdot \frac{\partial L}{\partial b} b:=bαbL

其中:

  • $ \alpha $:学习率(learning rate)
  • $ \frac{\partial L}{\partial W} $:当前 batch 的梯度
🔄 调用流程:
loss.backward()        # 计算梯度
optimizer.step()       # 根据梯度更新参数

⚠️ 注意:在每次更新前要调用 optimizer.zero_grad() 清空旧梯度,避免累积。


🚀 三、扩展:加入动量(Momentum)优化器

✅ 动量的意义:

动量法是一种改进的梯度下降方法,它利用之前梯度的方向来加速收敛,减少震荡。

🧮 更新公式:

v t = β v t − 1 − α ⋅ ∇ J ( θ ) v_t = \beta v_{t-1} - \alpha \cdot \nabla J(\theta) vt=βvt1αJ(θ)
θ : = θ + v t \theta := \theta + v_t θ:=θ+vt

其中:

  • vt :当前时刻的速度(动量项)
  • β:动量系数,通常取 0.9
  • α:学习率
  • ▽J(θ):当前 batch 的梯度

注:
在这里插入图片描述

🧩 PyTorch 实现示例:

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

💡 使用动量后,optimizer.step() 不再只是简单的梯度乘以学习率更新,而是结合了历史梯度方向进行加权更新。


🎯 四、一句话总结

在小批量梯度下降中:

  • 一个 epoch 是对整个训练集的一次完整遍历;
  • batch_size 决定每次用多少样本计算梯度;
  • optimizer.step() 是执行参数更新的关键操作,对应一次梯度下降;
  • 加入 动量(Momentum) 可以提升收敛速度和稳定性。

http://www.xdnf.cn/news/235675.html

相关文章:

  • pinia实现数据持久化插件pinia-plugin-persist-uni
  • 10、属性和数据处理---c++17
  • 突破SQL注入字符转义的实战指南:绕过技巧与防御策略
  • 《Ultralytics HUB:开启AI视觉新时代的密钥》
  • Stack--Queue 栈和队列
  • 前端基础之《Vue(13)—重要API》
  • Dify Agent节点的信息收集策略示例
  • 【效率提升】Vibe Coding时代如何正确使用输入法:自定义短语实现Prompt快捷输入
  • windows系统 压力测试技术
  • Github开通第三方平台OAuth登录及Java对接步骤
  • ES使用之查询方式
  • 空域伦理与AI自主边界的系统建构
  • 《冰雪传奇点卡版》:第二大陆介绍!
  • Java 手写jdbc访问数据库
  • 代理脚本——爬虫
  • 【MySQL】索引特性
  • JGQ511机械振打袋式除尘器实验台装置设备
  • 鸿蒙的StorageLink
  • BT137-ASEMI机器人功率器件专用BT137
  • 【Hive入门】Hive性能优化:执行计划分析EXPLAIN命令的使用
  • 41 python http之requests 库
  • spring中的@Configuration注解详解
  • pytorch的cuda版本依据nvcc --version与nvidia-smi
  • 企业架构之旅(4):TOGAF ADM 中业务架构——企业数字化转型的 “骨架”
  • 永磁同步电机控制算法--单矢量模型预测电流控制MPCC
  • # 实现中文情感分析:基于TextRNN的模型部署与应用
  • 软件测试52讲学习分享:深入理解单元测试
  • BI平台是什么意思?一文讲清BI平台的具体应用!
  • AWTK:一键切换皮肤,打造个性化UI
  • 开源版禅道本地安装卸载备份迁移小白教程