当前位置: 首页 > backend >正文

AdaFactor Optimizer 大模型训练优化器简介

文章目录

      • AdaFactor Optimizer 简介
      • 核心特点
      • 数学原理
      • 实际应用
      • 代码示例(PyTorch)
      • 总结

AdaFactor Optimizer 简介

AdaFactor 是一种用于训练深度学习模型的优化器,由谷歌在 2018 年提出,来自论文:

Noam M. Shazeer and Mitchell Stern. Adafactor: Adaptive learning rates with sublinear memory cost. ArXiv, abs/1804.04235, 2018.

它旨在解决传统自适应优化器(如 Adam)在训练大型模型时面临的一些问题,特别是内存消耗大和泛化能力可能受限的问题。

核心特点

  1. 降低内存消耗

    • 传统自适应优化器的问题:像 Adam 这样的优化器,会为每个模型参数维护两个一阶矩估计(均值)和二阶矩估计(未中心化的方差)的统计量。对于大型模型,参数数量庞大,这些统计量会占用大量内存。例如,一个拥有 10 亿参数的模型,使用 Adam 优化器时,仅存储这些统计量就需要消耗大量 GPU 内存。
    • AdaFactor 的改进:AdaFactor 采用了一种低秩近似的方法来存储二阶矩估计。它将二阶矩估计矩阵分解为两个低秩矩阵的乘积,从而大大减少了内存占用。具体来说,它将原本需要存储的 n × m n \times m n×m 的矩阵( n n n m m m 分别是参数矩阵的行数和列数)近似为两个较小的矩阵的乘积,使得内存消耗从 O ( n m ) O(nm) O(nm) 降低到 O ( n + m ) O(n + m) O(n+m)
  2. 自适应学习率调整

    • 原理:AdaFactor 继承了自适应优化器的优点,能够根据参数的历史梯度信息自动调整学习率。它通过计算一阶矩估计和二阶矩估计来动态地缩放梯度,使得不同参数能够以合适的步长进行更新。
    • 优势:与固定学习率的优化器相比,AdaFactor 可以更快地收敛,并且在处理不同尺度的特征时更加稳定。例如,在训练神经网络时,不同层的参数可能具有不同的梯度尺度,AdaFactor 能够自动适应这些差异,提高训练效率。
  3. 避免学习率过早衰减

    • 传统优化器的不足:一些自适应优化器在训练过程中可能会出现学习率过早衰减的问题,导致模型在后期训练中收敛速度变慢,甚至无法达到更好的性能。
    • AdaFactor 的解决方案:AdaFactor 采用了一种基于相对变化的学习率调整策略。它通过比较当前梯度和历史梯度的相对变化来决定是否调整学习率,而不是简单地依赖固定的衰减策略。这样可以避免学习率过早衰减,使模型在训练后期仍然能够保持较好的学习效率。

数学原理

  • 一阶矩估计:AdaFactor 计算梯度的一阶矩估计 m t m_t mt,类似于 Adam 中的做法,但使用了一种指数移动平均的方式进行更新:
    m t = β 1 m t − 1 + ( 1 − β 1 ) g t m_t = \beta_1 m_{t - 1} + (1 - \beta_1)g_t mt=β1mt1+(1β1)gt
    其中, g t g_t gt 是当前时刻的梯度, β 1 \beta_1 β1 是一个超参数,控制着一阶矩估计的衰减速度。

  • 二阶矩估计的低秩近似:AdaFactor 将二阶矩估计 V t V_t Vt 分解为两个低秩矩阵 R t R_t Rt C t C_t Ct 的乘积,即 V t ≈ R t C t T V_t \approx R_t C_t^T VtRtCtT。在更新过程中,分别对 R t R_t Rt C t C_t Ct 进行更新:
    R t = β 2 R t − 1 + ( 1 − β 2 ) g t ⊙ g t ⋅ 1 max ⁡ ( 1 , row_norm ( R t − 1 ) ) R_t = \beta_2 R_{t - 1} + (1 - \beta_2)g_t \odot g_t \cdot \frac{1}{\max(1, \text{row\_norm}(R_{t - 1}))} Rt=β2Rt1+(1β2)gtgtmax(1,row_norm(Rt1))1
    C t = β 2 C t − 1 + ( 1 − β 2 ) g t ⊙ g t ⋅ 1 max ⁡ ( 1 , col_norm ( C t − 1 ) ) C_t = \beta_2 C_{t - 1} + (1 - \beta_2)g_t \odot g_t \cdot \frac{1}{\max(1, \text{col\_norm}(C_{t - 1}))} Ct=β2Ct1+(1β2)gtgtmax(1,col_norm(Ct1))1
    其中, ⊙ \odot 表示逐元素相乘, β 2 \beta_2 β2 是另一个超参数,控制着二阶矩估计的衰减速度, row_norm \text{row\_norm} row_norm col_norm \text{col\_norm} col_norm 分别表示对矩阵的行和列进行归一化操作。

  • 参数更新:根据一阶矩估计和二阶矩估计的低秩近似,计算参数的更新量 Δ θ t \Delta \theta_t Δθt
    Δ θ t = − m t V t + ϵ \Delta \theta_t = -\frac{m_t}{\sqrt{V_t} + \epsilon} Δθt=Vt +ϵmt
    其中, ϵ \epsilon ϵ 是一个很小的常数,用于避免除以零。然后,使用更新量对参数进行更新:
    θ t + 1 = θ t + Δ θ t \theta_{t + 1} = \theta_t + \Delta \theta_t θt+1=θt+Δθt

实际应用

  • 大型语言模型训练:在训练像 BERT、GPT 这样的大型语言模型时,AdaFactor 可以显著减少内存消耗,使得在有限的硬件资源下能够训练更大的模型。例如,在训练一个拥有数十亿参数的语言模型时,使用 AdaFactor 优化器可以将内存占用降低数倍,从而允许在单个 GPU 或较少的 GPU 集群上完成训练。
  • 计算机视觉任务:在图像分类、目标检测等计算机视觉任务中,AdaFactor 也能够提高训练效率和模型性能。特别是在处理高分辨率图像和复杂模型结构时,其降低内存消耗的优势更加明显。

代码示例(PyTorch)

import torch
import torch.nn as nn
from torch.optim import Adam, AdamW
from transformers import AdaFactor, AdaFactorOptimizer  # 假设使用 transformers 库中的 AdaFactor# 定义一个简单的神经网络模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(100, 10)def forward(self, x):return self.fc(x)# 创建模型实例
model = SimpleModel()# 定义输入数据和标签
inputs = torch.randn(32, 100)
labels = torch.randint(0, 10, (32,))# 定义损失函数
criterion = nn.CrossEntropyLoss()# 使用 AdaFactor 优化器
optimizer = AdaFactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False)# 训练循环
for epoch in range(10):optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()print(f'Epoch [{epoch + 1}/10], Loss: {loss.item():.4f}')

总结

AdaFactor 优化器通过低秩近似的方法降低了内存消耗,同时保留了自适应优化器的优点,能够自适应地调整学习率,避免学习率过早衰减。在训练大型深度学习模型时,AdaFactor 是一种非常有效的优化器选择,特别是在内存资源有限的情况下。它在大型语言模型和计算机视觉等领域都有广泛的应用前景。

http://www.xdnf.cn/news/12602.html

相关文章:

  • 多线程2(Thread)
  • C++算法-动态规划2
  • 前端基础之《Vue(19)—状态管理》
  • 73 LV的使用(XFS文件系统)
  • CMA软件产品测试报告在哪申请?
  • Dify+Ollama搭建本地知识库
  • C/C++ 中附加包含目录、附加库目录与附加依赖项详解
  • 高精度滚珠导轨在医疗设备中的多元应用场景
  • 江科大读写内部flash到hal库实现
  • STTT(IF:40.8) 清华大学常智杰团队完成雾化外泌体治疗肺纤维化的I期临床试验
  • python学习打卡day46
  • DRV8833 电机控制芯片
  • STM32定时器的种类作用
  • 惠斯通电桥温度补偿优化解决方案
  • 《架构即未来》笔记
  • Cesium等高线
  • 新版双紫擒龙、紫紫红黄、动能二号源码指标源码公式讲解
  • 基于SmartPlayer的超低延迟RTSP播放器全平台开发实录
  • 【GESP真题解析】第 14 集 GESP 三级 2024 年 9 月编程题 1:平衡序列
  • MajicTryOn(基于wanvideo的虚拟试穿项目)
  • 单图像生成3D动画模型TripoSR的部署过程
  • 局域网聊天室系统的设计与实现【源码+文档】
  • 储能方案设计:鹧鸪云模拟软件优势尽显
  • 文件对话框
  • daz3d + PBRSkin (MDL)+ SSS
  • 【国产8K 50P小型化广播级摄像机X2023央视总台春晚】多图预警
  • MySQL基础(五)事务、DCL权限控制、视图、同义词、索引及练习
  • 学习数字孪生,为你的职业发展开辟新赛道
  • 港股TRS交易系统开发:跨境资本的精密调度引擎
  • Beckhoff(倍福)PLC 顺控程序转换条件解读