当前位置: 首页 > java >正文

Adam优化器

故事背景:你是一个想减肥的小明

目标:把体重从 80kg 减到 60kg(这相当于模型要找到最优参数)

1. 普通减肥法(SGD优化器)

  • 方法:每天称体重,发现重了就少吃,发现轻了就多吃点。

  • 问题

    • 今天体重涨了1斤,明天立刻绝食 → 反应过激(步子太大)

    • 连续三天体重没变,以为减不动了 → 轻易放弃(步子太小)

❌ 缺点:要么瞎折腾,要么躺平摆烂。

 

2. 升级版减肥法(Adam登场!)

Adam像你的 智能健身私教,它做两件事:

🔹 第一招:看趋势(动量 Momentum)
  • 例子
    教练不会因为某天体重波动就骂你,而是看 过去一周的趋势

    • 如果连续5天体重下降 → 说明方法有效,继续保持节奏

    • 如果连续3天体重反弹 → 说明吃多了,要加大运动量

  • 作用避免被单次波动带偏,稳步前进。

🔹 第二招:个性化调整(自适应学习率)
  • 例子
    教练发现:

    • 你的 大腿 特别难减(需要猛练)→ 给大腿 开大号运动量

    • 你的 手臂 容易瘦(练猛了会皮松)→ 给手臂 调小运动量

  • 作用不同部位(参数)用不同力度,精准打击!

3. Adam私教的工作流程

  1. 记录历史

    • 记下你每天每个部位的围度变化(存动量+历史梯度)

  2. 分析趋势

    • 大腿最近减得慢 → 下次重点加练(加大更新步长)

    • 腰围降太快 → 下次轻点练(减小更新步长)

  3. 动态调整

    • 根据每个部位的“顽固程度”,定制训练计划(自适应学习率)

✅ 效果:不蛮干、不放弃,用科学方法逼近目标体重!

Adam在AI训练中的真实作用

你的减肥Adam优化器AI训练目标
体重、腰围、腿围模型参数(Parameters)让模型预测更准
每天围度变化梯度(Gradients)参数该往哪调、调多少
看一周趋势动量(Momentum)减少震荡,稳定方向
按部位调整力度自适应学习率不同参数不同更新速度

 

为什么大模型都用Adam?

  • 处理混乱数据:像忽胖忽瘦的体重,Adam能排除噪音干扰

  • 应对复杂模型:像不同部位不同策略,Adam给每个参数“定制方案”

  • 更快见效:比蛮干(SGD)少走弯路,训练速度提升10倍⏰

💡 一句话总结:
Adam = 趋势跟踪器 + 个性化调节器
它让AI模型像科学减肥一样——不瞎折腾、精准高效地走向最优解!

 

一、Adam算法核心数学原理

1. 参数更新公式

设模型参数为 θθ,目标函数梯度为 gtgt​,学习率为 αα,Adam更新规则如下:

\begin{align}
m_t &= \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t & \text{(一阶动量)} \\
v_t &= \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g_t^2 & \text{(二阶动量)} \\
\hat{m}_t &= \frac{m_t}{1 - \beta_1^t} & \text{(偏差修正)} \\
\hat{v}_t &= \frac{v_t}{1 - \beta_2^t} \\
\theta_{t} &= \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon
\end{align}

 

2. 超参数作用
参数典型值数学意义工程影响
β1β1​0.9一阶动量衰减率控制梯度方向记忆强度
β2β2​0.999二阶动量衰减率控制梯度幅值自适应
ϵϵ1e-8数值稳定性常数防止除零错误
αα1e-3~5e-5基础学习率决定更新步长基准

二、混合精度训练中的关键实现

1. 内存布局优化

 

# FP16参数 + FP32 Master Copy (混合精度标准实践)
model_param_fp16 = torch.randn(10, dtype=torch.float16)  # 前向/反向传播使用
master_param_fp32 = model_param_fp16.float()             # 优化器更新使用# Adam状态存储 (FP32)
momentum = torch.zeros_like(master_param_fp32)            # 一阶动量
variance = torch.zeros_like(master_param_fp32)            # 二阶动量

 2. 更新步骤伪代码

 

def adam_update(param_fp16, master_fp32, m, v, grad_fp16, t, alpha=0.001):# 梯度转FP32 (避免精度损失)grad_fp32 = grad_fp16.float()# 更新动量m = beta1*m + (1-beta1)*grad_fp32v = beta2*v + (1-beta2)*(grad_fp32**2)# 偏差修正m_hat = m / (1 - beta1**t)v_hat = v / (1 - beta2**t)# 参数更新 (FP32空间)master_fp32 -= alpha * m_hat / (torch.sqrt(v_hat) + eps)# 同步到FP16param_fp16.copy_(master_fp32.half())

三、内存开销分析(关键题解)

设模型参数量为 PP:

组件数据类型内存占用总大小
ParametersFP162字节/参数2P2P
GradientsFP162字节/参数2P2P
Optimizer StatesFP3212字节/参数12P
├─ MomentumFP324字节
├─ VarianceFP324字节
└─ Master ParamsFP324字节
LossFP324字节可忽略

✅ 结论:优化器状态(12P) > 梯度(2P) + 参数(2P),Optimi zer states是内存瓶颈

四、工程级优化策略

1. 内存压缩技术
  • 8-bit Adam (e.g. bitsandbytes库)

    • 动量/方差用FP8存储 → 内存降至 8P

    • 更新时转FP32计算

  • 分片优化器 (ZeRO Stage 1)

    • 优化器状态在多卡间分区 → 每卡内存降至 12P/N12P/N

 2. 数值稳定性增强

# 梯度裁剪 + Adam (防止二阶动量爆炸)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)# 学习率暖身 (避免早期v_hat过小)
if step < warmup_steps:lr = base_lr * (step / warmup_steps)
3. 硬件适配优化
# NVIDIA A100 混合精度训练命令 (AMP + Adam)
torchrun --nproc_per_node=8 \--mixed_precision fp16 \--opt_mode adam \train.py# 昇腾910 特定优化 (华为CANN)
from npu.contrib.optimizer import AdamWNPU
optimizer = AdamWNPU(model.parameters(), lr=0.001)

五、与其他优化器的对比

特性AdamSGD+MomentumAdaGradLAMB
自适应学习率
方向+尺度自适应
内存占用/参数12B4B4B12B
超参数敏感性
大模型训练适用性主流需精细调参淘汰分布式常用

六、最佳实践建议

学习率策略 

# Cosine衰减 + 线性暖身
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6
)
  1. 参数初始化

    • 与学习率解耦:使用 kaiming_normal_ 初始化

    • 层缩放:对残差块输出乘 0.50.5​

  2. 故障排查

    • 梯度爆炸:检查 β2β2​ 是否接近1(推荐>0.99)

    • 收敛失败:尝试 ϵϵ 从1e-6调整至1e-8

前沿方向

  • Sophia优化器 (2023):用Hessian对角估计替代二阶动量,内存降至8P

  • CAME (2023):混合INT8/FP16状态存储,千亿模型内存降40%

http://www.xdnf.cn/news/15604.html

相关文章:

  • IMU噪声模型
  • 【数据结构】链表(linked list)
  • PostgreSQL 中的 pg_trgm 扩展详解
  • 命名实体识别15年研究全景:从规则到机器学习的演进(1991-2006)
  • Python 基础语法与数据类型(十三) - 实例方法、类方法、静态方法
  • SAP-ABAP:SAP的‘cl_http_utility=>escape_url‘对URL进行安全编码方法详解
  • Linux Swap区深度解析:为何禁用?何时需要?
  • 【程序地址空间】虚拟地址与页表转化
  • 基于Rust游戏引擎实践(Game)
  • 线上项目https看不了http的图片解决
  • 在分布式系统中,如何保证缓存与数据库的数据一致性?
  • docker 容器无法使用dns解析域名异常问题排查
  • springboot 整合spring-kafka客户端:SASL_SSL+PLAINTEXT方式
  • LeetCode20
  • 边界路由器
  • Baumer工业相机堡盟工业相机如何通过YoloV8模型实现人物识别(C#)
  • 如何做好DNA-SIP?
  • Redis完全指南:从基础到实战(含缓存问题、布隆过滤器、持久化及Spring Boot集成)
  • 数据结构 栈(2)--栈的实现
  • 4.PCL点云的数据结构
  • 「Chrome 开发环境快速屏蔽 CORS 跨域限制详细教程」*
  • springboot跨域问题 和 401
  • 人工智能基础知识笔记十四:文本转换成向量
  • Android 实现:当后台数据限制开启时,仅限制互联网APN。
  • 什么是“数据闭环”
  • Docker-Beta?ollama的完美替代品
  • MySQL高可用集群架构:主从复制、MGR与读写分离实战
  • TDengine 的可视化数据库操作工具 taosExplorer(安装包自带)
  • VMware Workstation Pro 17下载安装
  • VR全景园区:开启智慧园区新时代