【DeepSeek 学大模型推理】Fused Residual LayerNorm with Reduce-Scatter
这段描述涉及分布式训练中的张量并行(Tensor Parallelism, TP)技术,具体是注意力(Attention)层和多层感知机(MLP)层在不同设备(TP Group)间的通信优化。核心在于用 reduce_scatter
和 all_gather
这两个集合通信操作替代传统的 all_reduce
,从而降低通信开销并提升效率。
以下是详细解释:
-
背景:张量并行(TP)
- 在大模型训练中,单个设备(GPU)可能无法容纳整个模型(尤其是参数和激活值)。
- 张量并行 是将模型的单个层(如 Attention 层、MLP 层)的参数和计算水平拆分到多个设备(称为一个 TP Group 或 TP Domain)上。
- 每个设备只持有该层的一部分参数,并负责该层输入张量对应部分的计算。
- 为了得到完整的输出结果(供下一层使用或用于损失计算),这些设备之间需要进行通信来交换或聚合计算结果。
-
问题:传统通信 (
all_reduce
)- 在早期的 TP 实现(如 Megatron-LM 的早期版本)中,一个层(无论是 Attention 还是 MLP)计算完毕后,通常会在其 TP Group 内进行一次
all_reduce
操作。 all_reduce
做什么?- Reduce: 将所有设备上的输出张量求和(或其他操作,但通常是求和)。
- All: 将求和后的完整结果发送回所有设备。
- 缺点:
all_reduce
的通信量是world_size * tensor_size
(其中world_size
是 TP Group 大小)。虽然比reduce
+broadcast
高效,但对于大模型和大张量,其通信开销仍然巨大。
- 在早期的 TP 实现(如 Megatron-LM 的早期版本)中,一个层(无论是 Attention 还是 MLP)计算完毕后,通常会在其 TP Group 内进行一次
-
优化方案:
reduce_scatter
+all_gather
为了减少通信开销,现代框架(如 Megatron-LM, DeepSpeed)采用了更精细的通信策略,针对 Attention 层和 MLP 层的不同计算特性,分别使用不同的通信原语:-
reduce_scatter
(用于 Attention 层之后的输出聚合):- 做什么?
- Reduce: 每个设备持有 Attention 层输出的一个部分结果(张量块)。
reduce_scatter
首先对这些分散在不同设备上的部分结果进行逐元素求和(Reduce)。 - Scatter: 求和完成后,每个设备只获得最终完整输出张量中属于自己的那一块。
- Reduce: 每个设备持有 Attention 层输出的一个部分结果(张量块)。
- 为什么用在 Attention 后?
- Attention 层的输出通常会被立即输入给后续的 MLP 层。
- 在 TP 中,MLP 层的输入也需要在设备间进行拆分(每个设备处理输入的一部分)。
reduce_scatter
完美地完成了两件事:- 聚合了 Attention 的所有部分输出(Reduce)。
- 将聚合后的完整输出按 MLP 层需要的划分方式重新分发(Scatter)给 TP Group 内的每个设备。每个设备现在只持有完整 MLP 输入张量中自己需要处理的那一部分。
- 通信量优势: 输出张量大小为
[S, H]
(序列长度 x 隐藏层大小)。在world_size=N
的 TP Group 中:- 传统
all_reduce
通信量:~ 2 * (N - 1) * (S * H) / N
(近似值,考虑带宽优化算法)。 reduce_scatter
通信量:(S * H)
(每个设备发送自己的[S, H/N]
块,接收聚合后的[S, H/N]
块。总的通信量是N * (S * H / N) = S * H
)。通信量显著低于all_reduce
。
- 传统
- 做什么?
-
all_gather
(用于 MLP 层之后的输出聚合):- 做什么?
- 每个设备计算完 MLP 层后,得到输出张量的一个部分结果(张量块)。
- Gather: 每个设备需要收集 TP Group 内所有其他设备计算出的部分结果。
- All: 收集完成后,每个设备都拥有了完整的输出张量。
- 为什么用在 MLP 后?
- MLP 层的输出通常是该 Transformer Block 的最终输出,需要作为下一个 Block 的输入,或者用于计算损失。
- 下一个 Block(通常是 Attention)的输入需要是完整的张量(在序列长度维度上不拆分,通常在隐藏层维度拆分)。
all_gather
让每个设备都获得了完整的输出张量,为下一个 Block 的计算做好准备。
- 通信量优势: 输出张量大小为
[S, H]
。- 传统
all_reduce
通信量:~ 2 * (N - 1) * (S * H) / N
。 all_gather
通信量:(S * H)
(每个设备发送自己的[S, H/N]
块,接收其他所有设备的[S, H/N]
块,最终本地拼接成[S, H]
。总通信量N * (S * H / N) = S * H
)。通信量同样显著低于all_reduce
。
- 传统
- 做什么?
-
-
流程总结 (一个 Transformer Block 内)
假设一个 TP Group 有 N 个设备,处理一个输入张量X
(形状[S, H]
),在隐藏层维度H
上拆分:- Attention 层计算:
- 每个设备用自己的部分参数计算,得到部分输出
partial_attn_out_i
(形状[S, H/N]
)。
- 每个设备用自己的部分参数计算,得到部分输出
- Attention 后通信 (
reduce_scatter
):- 在 Attn TP Domain 内进行
reduce_scatter
。 - 结果:每个设备获得
mlp_input_i
(形状[S, H/N]
),这是 完整 Attention 输出在H
维度上拆分后属于该设备的那一部分。这个结果正好是下一个 MLP 层需要的输入拆分形式。
- 在 Attn TP Domain 内进行
- MLP 层计算:
- 每个设备用自己的 MLP 部分参数计算
mlp_input_i
,得到部分输出partial_mlp_out_i
(形状[S, H/N]
)。
- 每个设备用自己的 MLP 部分参数计算
- MLP 后通信 (
all_gather
):- 在 MLP TP Domain 内进行
all_gather
。 - 结果:每个设备获得完整的 MLP 输出
Y
(形状[S, H]
)。这个完整的Y
将作为下一个 Transformer Block 的输入(同样需要在H
维度上拆分)。
- 在 MLP TP Domain 内进行
- Attention 层计算:
-
关键点
reduce_scatter
用于 Attention 后: 目的是聚合并分发结果,为紧接着的需要输入拆分的 MLP 层做准备。它产生的是部分结果(给下一个层用)。all_gather
用于 MLP 后: 目的是收集完整的输出结果,为下一个可能需要完整输入(或不同拆分方式)的层(如下一个 Block 的 Attention)做准备。它产生的是完整结果。- 通信量优势: 组合使用
reduce_scatter
+all_gather
代替两个all_reduce
,其总通信量 (2 * S * H
) 远低于两个all_reduce
(~ 4 * (N - 1) * (S * H) / N
)。这是现代大规模 Transformer 训练的关键优化之一。 - “域”的含义: “Attn TP 域” 和 “MLP TP 域” 通常指的是同一个 TP Group,只是强调通信发生在负责 Attention 计算或 MLP 计算的这些设备之间。
简单来说:
这句话的意思是:在分布式训练的张量并行中,为了高效通信,不再在 Attention 层和 MLP 层后面都使用昂贵的 all_reduce
。而是改为:
- 在 Attention 层计算完后,在其 TP Group 内进行
reduce_scatter
。这步把各设备算出的 Attention 部分结果汇总起来,并按需拆分,得到的结果正好直接喂给每个设备上负责下一部分(MLP)的计算。 - 在 MLP 层计算完后,在其 TP Group 内进行
all_gather
。这步让每个设备收集齐所有其他设备算出的 MLP 部分结果,拼成完整的输出张量,供后续计算(如下一个 Block)使用。
这种策略 (reduce_scatter
+ all_gather
) 比在每个层后都用 all_reduce
(all_reduce
+ all_reduce
) 通信量更小,效率更高。
下面通过一个矩阵示例,结合分布式训练中的张量并行(TP)技术,说明 Attention 后使用 reduce_scatter
和 MLP 后使用 all_gather
的通信过程。假设:
- 隐藏层维度
H = 4
(为简化,忽略序列长度维度) - TP 度 = 2(2 个设备:GPU0 和 GPU1)
- Attention 输出:设备间按特征维度拆分,且是部分和(需求和)
- MLP 输出:设备间按特征维度拆分(无需求和)
1. Attention 层后的 reduce_scatter
流程
假设 Attention 层的输出是 部分和(因张量并行计算导致)。每个设备持有部分输出:
- GPU0 的 Attention 输出
Y0
:[[a00, a01, a02, a03], [a10, a11, a12, a13]]
(形状[2, 4]
,但实际是完整输出的部分和) - GPU1 的 Attention 输出
Y1
:[[b00, b01, b02, b03], [b10, b11, b12, b13]]
(同形状,也是部分和)
步骤 1:设备拆分本地张量
- GPU0 将
Y0
按列拆成两半:Y00 = [[a00, a01], [a10, a11]]
(前 2 列)Y01 = [[a02, a03], [a12, a13]]
(后 2 列)
- GPU1 将
Y1
按列拆成两半:Y10 = [[b00, b01], [b10, b11]]
(前 2 列)Y11 = [[b02, b03], [b12, b13]]
(后 2 列)
步骤 2:设备间交换数据
- GPU0 发送
Y01
给 GPU1 - GPU1 发送
Y10
给 GPU0
步骤 3:本地求和计算
- GPU0 计算前 2 列的和:
part0 = Y00 + Y10 = [[a00 + b00, a01 + b01], [a10 + b10, a11 + b11]]
- GPU1 计算后 2 列的和:
part1 = Y01 + Y11 = [[a02 + b02, a03 + b03], [a12 + b12, a13 + b13]]
步骤 4:添加残差输入
假设残差输入 X
是完整的(每个设备都有副本):
- GPU0 取
X
的前 2 列:X0 = [[x00, x01], [x10, x11]]
计算:result0 = X0 + part0
result0 = [[x00 + a00 + b00, x01 + a01 + b01],[x10 + a10 + b10, x11 + a11 + b11]]
- GPU1 取
X
的后 2 列:X1 = [[x02, x03], [x12, x13]]
计算:result1 = X1 + part1
result1 = [[x02 + a02 + b02, x03 + a03 + b03],[x12 + a12 + b12, x13 + a13 + b13]]
结果
- GPU0 持有
result0
(最终输出的前 2 列) - GPU1 持有
result1
(最终输出的后 2 列) - 通信量:仅交换拆分后的部分(
2 × 2
矩阵),总数据量为S × H = 2 × 4 = 8
。
2. MLP 层后的 all_gather
流程
MLP 层的输入是拆分后的 result0
和 result1
。假设 MLP 的输出按特征维度拆分:
- GPU0 的 MLP 输出
Z0
:[[c00, c01], [c10, c11]]
(前 2 列) - GPU1 的 MLP 输出
Z1
:[[c02, c03], [c12, c13]]
(后 2 列)
步骤:设备间交换数据
- GPU0 发送
Z0
给 GPU1 - GPU1 发送
Z1
给 GPU0
结果
- GPU0 和 GPU1 均获得完整输出:
full_output = [[c00, c01, c02, c03],[c10, c11, c12, c13]]
- 通信量:每个设备发送本地部分(
2 × 2
矩阵),总数据量为S × H = 2 × 4 = 8
。
通信效率对比
- 传统方案(每层后使用
all_reduce
):- Attention 后
all_reduce
通信量:~2 × (S × H)
(对部分和求和) - MLP 后
all_reduce
通信量:~2 × (S × H)
(对部分和求和) - 总通信量:
~4 × (S × H)
- Attention 后
- 优化方案(
reduce_scatter
+all_gather
):- Attention 后
reduce_scatter
通信量:S × H
- MLP 后
all_gather
通信量:S × H
- 总通信量:
2 × (S × H)
- Attention 后
- 优势:优化方案减少约 50% 的通信量。
关键点总结
操作 | 目的 | 输入 | 输出 | 设备间关系 |
---|---|---|---|---|
reduce_scatter | 聚合 Attention 输出并分发给 MLP 层 | 部分和(需求和) | 拆分后的结果(部分列) | 聚合 + 拆分 |
all_gather | 收集 MLP 输出供下一层使用 | 拆分后的部分列 | 完整的输出 | 收集并拼接完整数据 |
通过这种设计,分布式训练在保持计算正确性的同时,显著降低了通信开销。
下面通过矩阵示例详细说明如何将残差连接(Residual Connection)、层归一化(LayerNorm)和后续的MLP层输入准备融合在单个 reduce_scatter
操作中。这种优化策略常见于现代大模型训练框架(如Megatron-LM),核心思想是将数学计算与通信重叠,减少中间结果的存储和通信次数。
Fused Residual LayerNorm with Reduce-Scatter
优化原理:融合残差、LayerNorm和reduce_scatter
假设:
- 隐藏层维度
H = 4
(简化为4维向量) - TP度 = 2(2个设备:GPU0, GPU1)
- 序列长度
S = 2
(每个token对应一个4维向量) - 输入:
Attention
输出(部分和):[S, H]
- 残差输入(完整):
[S, H]
- 目标:得到LayerNorm后的输出,且已按MLP层需要的拆分方式分布(每个设备持有
[S, H/TP]
)
传统方案(分步操作,效率低)
- 问题:
reduce_scatter
后得到的是部分结果(如[S, H/TP]
),但LayerNorm需要完整向量计算均值和方差,需额外通信。
优化方案:融合计算与通信
关键步骤:
- 本地残差预加:每个设备将本地Attention输出与完整残差的对应部分相加。
- 本地统计量计算:基于预加结果,计算本地均值和方差所需的部分和(
sum
和square_sum
)。 - 融合通信:通过
reduce_scatter
同时完成:- 聚合Attention输出
- 分发聚合结果(MLP输入)
- 聚合LayerNorm的统计量
- 全局归一化:用聚合后的统计量归一化本地结果。
矩阵示例(H=4, TP=2, S=2)
输入数据:
- Attention输出(部分和):
- GPU0:
A0 = [[a00, a01, a02, a03], [a10, a11, a12, a13]]
- GPU1:
A1 = [[b00, b01, b02, b03], [b10, b11, b12, b13]]
- GPU0:
- 残差输入(完整,每个设备均有副本):
R = [[x00, x01, x02, x03], [x10, x11, x12, x13]]
步骤1: 本地残差预加
每个设备将本地Attention输出与残差输入的对应列相加(按TP拆分计划):
-
GPU0(负责前2列):
Z0 = A0[:, 0:2] + R[:, 0:2]
= [[a00 + x00, a01 + x01], [a10 + x10, a11 + x11]]
-
GPU1(负责后2列):
Z1 = A1[:, 2:4] + R[:, 2:4]
= [[a02 + x02, a03 + x03], [a12 + x12, a13 + x13]]
✅ 优化点:直接操作局部数据,避免完整张量计算。
步骤2: 本地统计量计算
LayerNorm需要每个token向量的全局均值和方差:
- 均值公式:
μ = Σ(z_i) / H
- 方差公式:
σ² = Σ(z_i²) / H - μ²
每个设备计算本地部分的 部分和
与 部分平方和
:
-
GPU0:
sum0 = [ (a00+x00) + (a01+x01), (a10+x10) + (a11+x11) ] = [s00, s10]
sq_sum0 = [ (a00+x00)² + (a01+x01)², (a10+x10)² + (a11+x11)² ] = [q00, q10]
-
GPU1:
sum1 = [ (a02+x02) + (a03+x03), (a12+x12) + (a13+x13) ] = [s01, s11]
sq_sum1 = [ (a02+x02)² + (a03+x03)², (a12+x12)² + (a13+x13)² ] = [q01, q11]
步骤3: 融合通信(reduce_scatter)
通过一次 reduce_scatter
同时完成:
- 聚合Attention输出(求和)
- 分发结果(按列拆分)
- 聚合统计量(
sum
和sq_sum
)
- GPU0发送:
Z0
的后半部分 +sum0
+sq_sum0
- GPU1发送:
Z1
的前半部分 +sum1
+sq_sum1
通信后结果:
-
GPU0获得:
- 聚合的Attention输出(前2列):
attn_out0 = [[a00 + b00, a01 + b01], [a10 + b10, a11 + b11]]
- 全局统计量(每个token):
global_sum = [s00 + s01, s10 + s11]
global_sq_sum = [q00 + q01, q10 + q11]
- 聚合的Attention输出(前2列):
-
GPU1获得:
- 聚合的Attention输出(后2列):
attn_out1 = [[a02 + b02, a03 + b03], [a12 + b12, a13 + b13]]
- 全局统计量(同GPU0)
- 聚合的Attention输出(后2列):
✅ 优化点:统计量聚合(标量)的通信开销极小,融合后不增加额外成本。
步骤4: 全局归一化
每个设备使用聚合后的统计量归一化本地结果:
# GPU0计算(以第一个token为例)
μ = global_sum[0] / H # 全局均值
σ² = global_sq_sum[0]/H - μ² # 全局方差
z0_normalized = (attn_out0[0] - μ) / sqrt(σ² + ε) # 归一化# GPU1同理
最终得到:
- GPU0: 归一化后的前2列
[[z00_norm, z01_norm], [z10_norm, z11_norm]]
- GPU1: 归一化后的后2列
[[z02_norm, z03_norm], [z12_norm, z13_norm]]
通信效率对比
方案 | 通信操作 | 通信量 |
---|---|---|
传统方案 | reduce_scatter + all_reduce(统计量) | S*H + 2*S |
融合方案 | 单次reduce_scatter | S*H (统计量内嵌) |
其中
S*H
是Attention聚合的通信量,2*S
是统计量聚合的通信量。融合后节省2*S
的通信开销。
关键优化总结
- 残差预加:在通信前将残差加到本地Attention输出,避免完整张量操作。
- 统计量内嵌:将LayerNorm的局部统计量计算与
reduce_scatter
融合,避免额外的all_reduce
。 - 归一化本地化:用聚合后的全局统计量在本地完成归一化。
这种设计显著减少了通信次数和内存占用,是Megatron-LM等框架实现千亿级模型高效训练的核心技术之一。