当前位置: 首页 > news >正文

Megatron-LM(模型并行)

Megatron-LM: Training Multi-Billion Parameter Language Models Using
Model Parallelism

1. 技术设计原则

Megatron-LM 提出轻量级层内模型并行,无需定制编译器或修改框架,仅通过在 PyTorch 原生代码中插入少量通信操作(如all-reduce)实现,且与流水线模型并行正交互补,可灵活组合。

2.背景:矩阵分块计算

参考:https://www.bilibili.com/video/BV1HdXtY9EuF/?share_source=copy_web&vd_source=0f3d85b09673431159069a2a9a3da50c

在这里插入图片描述
矩阵XXXYYY相乘,即计算XYXYXY,有两种分块运算方式:

  1. YYY拆分为[Y1,Y2][Y_1,Y_2][Y1,Y2]XXX不变
    XmnYnk=Xmn[Y1nk2,Y2nk2]=[XY1,XY2]X^{mn}Y^{nk}=X^{mn}[Y_1^{n\frac{k}{2}},Y_2^{n\frac{k}{2}}]=[XY_1,XY_2]XmnYnk=Xmn[Y1n2k,Y2n2k]=[XY1,XY2]
  2. YYY拆分为 [Y1Y2]\begin{bmatrix} Y_1 \\ Y_2 \end{bmatrix}[Y1Y2],把XXX拆分为[X1,X2][X_1,X_2][X1,X2]
    XmnYnk=[X1mn2,X2mn2]×[Y1n2kY2n2k]=(X1Y1)mk+(X2Y2)mkX^{mn}Y^{nk}= \begin{bmatrix}X_1^{m \frac{n}{2}},X_2^{m \frac{n}{2}} \end{bmatrix} \times \begin{bmatrix} Y_1^{\frac{n}{2} k} \\ Y_2^{\frac{n}{2} k} \end{bmatrix} = (X_1Y_1)^{mk} + (X_2Y_2)^{m k}XmnYnk=[X1m2n,X2m2n]×[Y12nkY22nk]=(X1Y1)mk+(X2Y2)mk

3. 关键模块并行化实现

这一部分的图解为作者自己根据理解画的,如果有错误请指正

(1)前馈网络层(MLP)

公式:FFN(X)=σ(XA)BFFN(X)=\sigma(XA)BFFN(X)=σ(XA)B
设序列长度为lll,隐藏层维度为ddd,前馈网络的隐藏层维度为dFFNd_{FFN}dFFNA∈Rd×dFFN,B∈RdFFN×dA \in \mathbb{R}^{d \times d_{FFN}}, B \in \mathbb{R}^{d_{FFN} \times d}ARd×dFFN,BRdFFN×d
张量并行-前馈网络层

权重矩阵拆分策略:
第一层线性层的权重矩阵按列拆分(A=[A1,A2]A=[A_1,A_2]A=[A1,A2]

  • 使 GeLU 非线性激活可在各 GPU 上独立计算,避免中间同步;
  • 即使得GeLU(XA)=[GeLU(XA1),⋯,GeLU(XAn)]\text{GeLU}(XA)=[ \text{GeLU}(XA_1),\cdots ,\text{GeLU}(XA_n)]GeLU(XA)=[GeLU(XA1),,GeLU(XAn)]

第二层线性层的权重矩阵按行拆分,直接接收 GeLU 输出

  • 在前向传播时仅需对第二层的输出做一次 All-Reduce 聚合输出。
  • 在反向传播时仅需在返回到输入时做一次 All-Reduce 聚合梯度。

通信优化:整个 MLP 模块仅需 2 次 all-reduce 操作(前向1次、反向1次),无额外同步点。

(2)多头注意力(Multi-Head Attention)模块

在这里插入图片描述
注意力头拆分:

  • 将 Q、K、V 对应的权重矩阵按拆分,每个 GPU 负责部分注意力头的计算,无需中间通信;
    注意力输出层权重按拆分,直接接收并行计算结果,仅需在反向传播聚合梯度。

优势:充分利用注意力头的天然并行性,每个 GPU 仅处理部分头的计算,降低单设备内存压力。

(3)输入层与输出层优化

张量并行-输入层
输入嵌入:

  • 按词汇表维度列拆分嵌入矩阵(E=[E1,E2]E=[E_1,E_2]E=[E1,E2]
  • 通过 g 算子(前向 all-reduce)聚合结果,避免单 GPU 存储完整词汇表。

在这里插入图片描述

输出层与损失计算:

  • 融合最终线性层的输出与交叉熵损失计算,直接在各 GPU 上计算局部损失后聚合
  • 无需传输大规模 logits,减少通信量从b×s×vb×s×vb×s×vb×sb×sb×s
  • bbb 为批次大小、sss 为序列长度、vvv 为词汇表大小

4.混合并行策略(模型+数据并行)

混合并行策略
GPU分组:

  • 将 GPU 划分为模型并行组(如8个GPU一组,共同承载一个模型)和数据并行组(不同模型并行组中同位置 GPU 组成,负责梯度同步)
  • 总 GPU 数 = 模型并行度 × 数据并行度
http://www.xdnf.cn/news/1405351.html

相关文章:

  • 2025 年 AI 发展十大预测:多模态融合、边缘 AI 普及将成核心增长点
  • Redis数据类型概览:除了五大基础类型还有哪些?
  • 【适度精简】Windows 7 旗舰版-emmy精简系统
  • SpringAI应用开发工程师高阶面试剧本与知识点全解析(含RAG、多租户、流式推理、企业落地场景)
  • leetcode2(移除元素)
  • windows32位下载谷歌浏览器的地址
  • Twitter舆情裂变链:指纹云手机跨账号协同机制提升互动率200%
  • 大数据在UI前端的应用深化研究:用户行为数据的跨平台关联分析
  • 优化器全指南:从原理到调优实战
  • DrissionPage 实战:高效爬取网页数据并保存为 CSV 的全流程解析
  • 什么是雪花算法
  • Western Blot 样本制备完整流程:从细胞 / 组织到变性样品的关键步骤与细节
  • Selenium自动化测试快速入门指南
  • 玄机靶场 | 第五届红明谷-异常行为溯源
  • MCP进阶指南:如何挑选最适合你的AI助手“装备“
  • [光学原理与应用-332]:ZEMAX - 序列模式与非序列模式的本质、比较
  • JavaScript 中的 this 关键字
  • Python远程文件管理移动端适配与跨平台优化实战
  • 【自记】MaxCompute 中 对于“数据量大、耗时久、非实时”任务的设置建议
  • Linux 下 Docker 容器部署指南(Java + Redis 示例)
  • 2025年水库单北斗GNSS变形监测TOP3系统推荐榜单
  • C++ 之 【map和set的模拟实现】(只涉及map和set的插入、迭代器以及map的operator[]函数)
  • 使用 JavaScript 构建 RAG(检索增强生成)库:原理与实现
  • TechPowerUp GPU-Z中文版:专业显卡检测工具
  • 多教师语言感知知识蒸馏:提升多语种语音情绪识别的新方法
  • FPGA 实现FOC 无刷电机控制器
  • 数字化赋能,鹧鸪云重塑光伏电站资产管理新范式
  • DDR5 介绍
  • C/C++:AddressSanitizer内存检测工具
  • 基于单片机甲醛浓度检测报警系统Proteus仿真(含全部资料)