Megatron-LM(模型并行)
Megatron-LM: Training Multi-Billion Parameter Language Models Using
Model Parallelism
1. 技术设计原则
Megatron-LM 提出轻量级层内模型并行,无需定制编译器或修改框架,仅通过在 PyTorch 原生代码中插入少量通信操作(如all-reduce)实现,且与流水线模型并行正交互补,可灵活组合。
2.背景:矩阵分块计算
参考:https://www.bilibili.com/video/BV1HdXtY9EuF/?share_source=copy_web&vd_source=0f3d85b09673431159069a2a9a3da50c
矩阵XXX和YYY相乘,即计算XYXYXY,有两种分块运算方式:
- 把YYY按列拆分为[Y1,Y2][Y_1,Y_2][Y1,Y2],XXX不变
● XmnYnk=Xmn[Y1nk2,Y2nk2]=[XY1,XY2]X^{mn}Y^{nk}=X^{mn}[Y_1^{n\frac{k}{2}},Y_2^{n\frac{k}{2}}]=[XY_1,XY_2]XmnYnk=Xmn[Y1n2k,Y2n2k]=[XY1,XY2] - 把YYY按行拆分为 [Y1Y2]\begin{bmatrix} Y_1 \\ Y_2 \end{bmatrix}[Y1Y2],把XXX按列拆分为[X1,X2][X_1,X_2][X1,X2]
● XmnYnk=[X1mn2,X2mn2]×[Y1n2kY2n2k]=(X1Y1)mk+(X2Y2)mkX^{mn}Y^{nk}= \begin{bmatrix}X_1^{m \frac{n}{2}},X_2^{m \frac{n}{2}} \end{bmatrix} \times \begin{bmatrix} Y_1^{\frac{n}{2} k} \\ Y_2^{\frac{n}{2} k} \end{bmatrix} = (X_1Y_1)^{mk} + (X_2Y_2)^{m k}XmnYnk=[X1m2n,X2m2n]×[Y12nkY22nk]=(X1Y1)mk+(X2Y2)mk
3. 关键模块并行化实现
这一部分的图解为作者自己根据理解画的,如果有错误请指正
(1)前馈网络层(MLP)
公式:FFN(X)=σ(XA)BFFN(X)=\sigma(XA)BFFN(X)=σ(XA)B
设序列长度为lll,隐藏层维度为ddd,前馈网络的隐藏层维度为dFFNd_{FFN}dFFN,A∈Rd×dFFN,B∈RdFFN×dA \in \mathbb{R}^{d \times d_{FFN}}, B \in \mathbb{R}^{d_{FFN} \times d}A∈Rd×dFFN,B∈RdFFN×d。
权重矩阵拆分策略:
第一层线性层的权重矩阵按列拆分(A=[A1,A2]A=[A_1,A_2]A=[A1,A2])
- 使 GeLU 非线性激活可在各 GPU 上独立计算,避免中间同步;
- 即使得GeLU(XA)=[GeLU(XA1),⋯,GeLU(XAn)]\text{GeLU}(XA)=[ \text{GeLU}(XA_1),\cdots ,\text{GeLU}(XA_n)]GeLU(XA)=[GeLU(XA1),⋯,GeLU(XAn)]
第二层线性层的权重矩阵按行拆分,直接接收 GeLU 输出
- 在前向传播时仅需对第二层的输出做一次 All-Reduce 聚合输出。
- 在反向传播时仅需在返回到输入时做一次 All-Reduce 聚合梯度。
通信优化:整个 MLP 模块仅需 2 次 all-reduce 操作(前向1次、反向1次),无额外同步点。
(2)多头注意力(Multi-Head Attention)模块
注意力头拆分:
- 将 Q、K、V 对应的权重矩阵按列拆分,每个 GPU 负责部分注意力头的计算,无需中间通信;
注意力输出层权重按行拆分,直接接收并行计算结果,仅需在反向传播聚合梯度。
优势:充分利用注意力头的天然并行性,每个 GPU 仅处理部分头的计算,降低单设备内存压力。
(3)输入层与输出层优化
输入嵌入:
- 按词汇表维度列拆分嵌入矩阵(E=[E1,E2]E=[E_1,E_2]E=[E1,E2])
- 通过 g 算子(前向 all-reduce)聚合结果,避免单 GPU 存储完整词汇表。
输出层与损失计算:
- 融合最终线性层的输出与交叉熵损失计算,直接在各 GPU 上计算局部损失后聚合
- 无需传输大规模 logits,减少通信量从b×s×vb×s×vb×s×v至b×sb×sb×s
- bbb 为批次大小、sss 为序列长度、vvv 为词汇表大小
4.混合并行策略(模型+数据并行)
GPU分组:
- 将 GPU 划分为模型并行组(如8个GPU一组,共同承载一个模型)和数据并行组(不同模型并行组中同位置 GPU 组成,负责梯度同步)
- 总 GPU 数 = 模型并行度 × 数据并行度