当前位置: 首页 > backend >正文

【DeepSeek 学大模型推理】Fused Residual LayerNorm with Reduce-Scatter

这段描述涉及分布式训练中的张量并行(Tensor Parallelism, TP)技术,具体是注意力(Attention)层多层感知机(MLP)层在不同设备(TP Group)间的通信优化。核心在于用 reduce_scatterall_gather 这两个集合通信操作替代传统的 all_reduce,从而降低通信开销并提升效率。

以下是详细解释:

  1. 背景:张量并行(TP)

    • 在大模型训练中,单个设备(GPU)可能无法容纳整个模型(尤其是参数和激活值)。
    • 张量并行 是将模型的单个层(如 Attention 层、MLP 层)的参数和计算水平拆分到多个设备(称为一个 TP GroupTP Domain)上。
    • 每个设备只持有该层的一部分参数,并负责该层输入张量对应部分的计算。
    • 为了得到完整的输出结果(供下一层使用或用于损失计算),这些设备之间需要进行通信来交换或聚合计算结果。
  2. 问题:传统通信 (all_reduce)

    • 在早期的 TP 实现(如 Megatron-LM 的早期版本)中,一个层(无论是 Attention 还是 MLP)计算完毕后,通常会在其 TP Group 内进行一次 all_reduce 操作。
    • all_reduce 做什么?
      • Reduce: 将所有设备上的输出张量求和(或其他操作,但通常是求和)。
      • All: 将求和后的完整结果发送回所有设备。
    • 缺点: all_reduce 的通信量是 world_size * tensor_size (其中 world_size 是 TP Group 大小)。虽然比 reduce + broadcast 高效,但对于大模型和大张量,其通信开销仍然巨大。
  3. 优化方案:reduce_scatter + all_gather
    为了减少通信开销,现代框架(如 Megatron-LM, DeepSpeed)采用了更精细的通信策略,针对 Attention 层和 MLP 层的不同计算特性,分别使用不同的通信原语

    • reduce_scatter (用于 Attention 层之后的输出聚合):

      • 做什么?
        • Reduce: 每个设备持有 Attention 层输出的一个部分结果(张量块)。reduce_scatter 首先对这些分散在不同设备上的部分结果进行逐元素求和(Reduce)。
        • Scatter: 求和完成后,每个设备只获得最终完整输出张量中属于自己的那一块
      • 为什么用在 Attention 后?
        • Attention 层的输出通常会被立即输入给后续的 MLP 层。
        • 在 TP 中,MLP 层的输入也需要在设备间进行拆分(每个设备处理输入的一部分)。
        • reduce_scatter 完美地完成了两件事:
          1. 聚合了 Attention 的所有部分输出(Reduce)。
          2. 将聚合后的完整输出按 MLP 层需要的划分方式重新分发(Scatter)给 TP Group 内的每个设备。每个设备现在只持有完整 MLP 输入张量中自己需要处理的那一部分。
      • 通信量优势: 输出张量大小为 [S, H] (序列长度 x 隐藏层大小)。在 world_size=N 的 TP Group 中:
        • 传统 all_reduce 通信量:~ 2 * (N - 1) * (S * H) / N (近似值,考虑带宽优化算法)。
        • reduce_scatter 通信量:(S * H) (每个设备发送自己的 [S, H/N] 块,接收聚合后的 [S, H/N] 块。总的通信量是 N * (S * H / N) = S * H)。通信量显著低于 all_reduce
    • all_gather (用于 MLP 层之后的输出聚合):

      • 做什么?
        • 每个设备计算完 MLP 层后,得到输出张量的一个部分结果(张量块)。
        • Gather: 每个设备需要收集 TP Group 内所有其他设备计算出的部分结果。
        • All: 收集完成后,每个设备都拥有了完整的输出张量
      • 为什么用在 MLP 后?
        • MLP 层的输出通常是该 Transformer Block 的最终输出,需要作为下一个 Block 的输入,或者用于计算损失。
        • 下一个 Block(通常是 Attention)的输入需要是完整的张量(在序列长度维度上不拆分,通常在隐藏层维度拆分)。
        • all_gather 让每个设备都获得了完整的输出张量,为下一个 Block 的计算做好准备。
      • 通信量优势: 输出张量大小为 [S, H]
        • 传统 all_reduce 通信量:~ 2 * (N - 1) * (S * H) / N
        • all_gather 通信量:(S * H) (每个设备发送自己的 [S, H/N] 块,接收其他所有设备的 [S, H/N] 块,最终本地拼接成 [S, H]。总通信量 N * (S * H / N) = S * H)。通信量同样显著低于 all_reduce
  4. 流程总结 (一个 Transformer Block 内)
    假设一个 TP Group 有 N 个设备,处理一个输入张量 X (形状 [S, H]),在隐藏层维度 H 上拆分:

    1. Attention 层计算:
      • 每个设备用自己的部分参数计算,得到部分输出 partial_attn_out_i (形状 [S, H/N])。
    2. Attention 后通信 (reduce_scatter):
      • Attn TP Domain 内进行 reduce_scatter
      • 结果:每个设备获得 mlp_input_i (形状 [S, H/N]),这是 完整 Attention 输出H 维度上拆分后属于该设备的那一部分。这个结果正好是下一个 MLP 层需要的输入拆分形式。
    3. MLP 层计算:
      • 每个设备用自己的 MLP 部分参数计算 mlp_input_i,得到部分输出 partial_mlp_out_i (形状 [S, H/N])。
    4. MLP 后通信 (all_gather):
      • MLP TP Domain 内进行 all_gather
      • 结果:每个设备获得完整的 MLP 输出 Y (形状 [S, H])。这个完整的 Y 将作为下一个 Transformer Block 的输入(同样需要在 H 维度上拆分)。
  5. 关键点

    • reduce_scatter 用于 Attention 后: 目的是聚合并分发结果,为紧接着的需要输入拆分的 MLP 层做准备。它产生的是部分结果(给下一个层用)。
    • all_gather 用于 MLP 后: 目的是收集完整的输出结果,为下一个可能需要完整输入(或不同拆分方式)的层(如下一个 Block 的 Attention)做准备。它产生的是完整结果
    • 通信量优势: 组合使用 reduce_scatter + all_gather 代替两个 all_reduce,其总通信量 (2 * S * H) 远低于两个 all_reduce (~ 4 * (N - 1) * (S * H) / N)。这是现代大规模 Transformer 训练的关键优化之一。
    • “域”的含义: “Attn TP 域” 和 “MLP TP 域” 通常指的是同一个 TP Group,只是强调通信发生在负责 Attention 计算或 MLP 计算的这些设备之间。

简单来说:

这句话的意思是:在分布式训练的张量并行中,为了高效通信,不再在 Attention 层和 MLP 层后面都使用昂贵的 all_reduce。而是改为:

  1. 在 Attention 层计算完后,在其 TP Group 内进行 reduce_scatter。这步把各设备算出的 Attention 部分结果汇总起来,并按需拆分,得到的结果正好直接喂给每个设备上负责下一部分(MLP)的计算。
  2. 在 MLP 层计算完后,在其 TP Group 内进行 all_gather。这步让每个设备收集齐所有其他设备算出的 MLP 部分结果,拼成完整的输出张量,供后续计算(如下一个 Block)使用。

这种策略 (reduce_scatter + all_gather) 比在每个层后都用 all_reduce (all_reduce + all_reduce) 通信量更小,效率更高

下面通过一个矩阵示例,结合分布式训练中的张量并行(TP)技术,说明 Attention 后使用 reduce_scatter 和 MLP 后使用 all_gather 的通信过程。假设:

  • 隐藏层维度 H = 4(为简化,忽略序列长度维度)
  • TP 度 = 2(2 个设备:GPU0 和 GPU1)
  • Attention 输出:设备间按特征维度拆分,且是部分和(需求和)
  • MLP 输出:设备间按特征维度拆分(无需求和)

1. Attention 层后的 reduce_scatter 流程

假设 Attention 层的输出是 部分和(因张量并行计算导致)。每个设备持有部分输出:

  • GPU0 的 Attention 输出 Y0[[a00, a01, a02, a03], [a10, a11, a12, a13]]
    (形状 [2, 4],但实际是完整输出的部分和)
  • GPU1 的 Attention 输出 Y1[[b00, b01, b02, b03], [b10, b11, b12, b13]]
    (同形状,也是部分和)
步骤 1:设备拆分本地张量
  • GPU0Y0 按列拆成两半:
    • Y00 = [[a00, a01], [a10, a11]](前 2 列)
    • Y01 = [[a02, a03], [a12, a13]](后 2 列)
  • GPU1Y1 按列拆成两半:
    • Y10 = [[b00, b01], [b10, b11]](前 2 列)
    • Y11 = [[b02, b03], [b12, b13]](后 2 列)
步骤 2:设备间交换数据
  • GPU0 发送 Y01 给 GPU1
  • GPU1 发送 Y10 给 GPU0
步骤 3:本地求和计算
  • GPU0 计算前 2 列的和:
    part0 = Y00 + Y10 = [[a00 + b00, a01 + b01], [a10 + b10, a11 + b11]]
    
  • GPU1 计算后 2 列的和:
    part1 = Y01 + Y11 = [[a02 + b02, a03 + b03], [a12 + b12, a13 + b13]]
    
步骤 4:添加残差输入

假设残差输入 X 是完整的(每个设备都有副本):

  • GPU0X 的前 2 列:X0 = [[x00, x01], [x10, x11]]
    计算:result0 = X0 + part0
    result0 = [[x00 + a00 + b00, x01 + a01 + b01],[x10 + a10 + b10, x11 + a11 + b11]]
    
  • GPU1X 的后 2 列:X1 = [[x02, x03], [x12, x13]]
    计算:result1 = X1 + part1
    result1 = [[x02 + a02 + b02, x03 + a03 + b03],[x12 + a12 + b12, x13 + a13 + b13]]
    
结果
  • GPU0 持有 result0(最终输出的前 2 列)
  • GPU1 持有 result1(最终输出的后 2 列)
  • 通信量:仅交换拆分后的部分(2 × 2 矩阵),总数据量为 S × H = 2 × 4 = 8

2. MLP 层后的 all_gather 流程

MLP 层的输入是拆分后的 result0result1。假设 MLP 的输出按特征维度拆分:

  • GPU0 的 MLP 输出 Z0[[c00, c01], [c10, c11]](前 2 列)
  • GPU1 的 MLP 输出 Z1[[c02, c03], [c12, c13]](后 2 列)
步骤:设备间交换数据
  • GPU0 发送 Z0 给 GPU1
  • GPU1 发送 Z1 给 GPU0
结果
  • GPU0 和 GPU1 均获得完整输出
    full_output = [[c00, c01, c02, c03],[c10, c11, c12, c13]]
    
  • 通信量:每个设备发送本地部分(2 × 2 矩阵),总数据量为 S × H = 2 × 4 = 8

通信效率对比

  • 传统方案(每层后使用 all_reduce):
    • Attention 后 all_reduce 通信量:~2 × (S × H)(对部分和求和)
    • MLP 后 all_reduce 通信量:~2 × (S × H)(对部分和求和)
    • 总通信量~4 × (S × H)
  • 优化方案reduce_scatter + all_gather):
    • Attention 后 reduce_scatter 通信量:S × H
    • MLP 后 all_gather 通信量:S × H
    • 总通信量2 × (S × H)
  • 优势:优化方案减少约 50% 的通信量。

关键点总结

操作目的输入输出设备间关系
reduce_scatter聚合 Attention 输出并分发给 MLP 层部分和(需求和)拆分后的结果(部分列)聚合 + 拆分
all_gather收集 MLP 输出供下一层使用拆分后的部分列完整的输出收集并拼接完整数据

通过这种设计,分布式训练在保持计算正确性的同时,显著降低了通信开销。

下面通过矩阵示例详细说明如何将残差连接(Residual Connection)、层归一化(LayerNorm)和后续的MLP层输入准备融合在单个 reduce_scatter 操作中。这种优化策略常见于现代大模型训练框架(如Megatron-LM),核心思想是将数学计算与通信重叠,减少中间结果的存储和通信次数


Fused Residual LayerNorm with Reduce-Scatter

优化原理:融合残差、LayerNorm和reduce_scatter

假设:

  • 隐藏层维度 H = 4(简化为4维向量)
  • TP度 = 2(2个设备:GPU0, GPU1)
  • 序列长度 S = 2(每个token对应一个4维向量)
  • 输入:
    • Attention 输出(部分和):[S, H]
    • 残差输入(完整):[S, H]
  • 目标:得到LayerNorm后的输出,且已按MLP层需要的拆分方式分布(每个设备持有 [S, H/TP]

传统方案(分步操作,效率低)

Attention输出
Reduce-Scatter
残差相加
LayerNorm
MLP输入
  • 问题reduce_scatter 后得到的是部分结果(如 [S, H/TP]),但LayerNorm需要完整向量计算均值和方差,需额外通信。

优化方案:融合计算与通信

Attention输出
本地残差预加
本地统计量计算
Reduce-Scatter
全局统计量+归一化
MLP输入

关键步骤

  1. 本地残差预加:每个设备将本地Attention输出与完整残差的对应部分相加。
  2. 本地统计量计算:基于预加结果,计算本地均值和方差所需的部分和(sumsquare_sum)。
  3. 融合通信:通过 reduce_scatter 同时完成:
    • 聚合Attention输出
    • 分发聚合结果(MLP输入)
    • 聚合LayerNorm的统计量
  4. 全局归一化:用聚合后的统计量归一化本地结果。

矩阵示例(H=4, TP=2, S=2)

输入数据
  • Attention输出(部分和)
    • GPU0: A0 = [[a00, a01, a02, a03], [a10, a11, a12, a13]]
    • GPU1: A1 = [[b00, b01, b02, b03], [b10, b11, b12, b13]]
  • 残差输入(完整,每个设备均有副本)
    • R = [[x00, x01, x02, x03], [x10, x11, x12, x13]]

步骤1: 本地残差预加

每个设备将本地Attention输出与残差输入的对应列相加(按TP拆分计划):

  • GPU0(负责前2列):
    Z0 = A0[:, 0:2] + R[:, 0:2]
    = [[a00 + x00, a01 + x01], [a10 + x10, a11 + x11]]

  • GPU1(负责后2列):
    Z1 = A1[:, 2:4] + R[:, 2:4]
    = [[a02 + x02, a03 + x03], [a12 + x12, a13 + x13]]

优化点:直接操作局部数据,避免完整张量计算。


步骤2: 本地统计量计算

LayerNorm需要每个token向量的全局均值方差

  • 均值公式:μ = Σ(z_i) / H
  • 方差公式:σ² = Σ(z_i²) / H - μ²

每个设备计算本地部分的 部分和部分平方和

  • GPU0

    • sum0 = [ (a00+x00) + (a01+x01), (a10+x10) + (a11+x11) ] = [s00, s10]
    • sq_sum0 = [ (a00+x00)² + (a01+x01)², (a10+x10)² + (a11+x11)² ] = [q00, q10]
  • GPU1

    • sum1 = [ (a02+x02) + (a03+x03), (a12+x12) + (a13+x13) ] = [s01, s11]
    • sq_sum1 = [ (a02+x02)² + (a03+x03)², (a12+x12)² + (a13+x13)² ] = [q01, q11]

步骤3: 融合通信(reduce_scatter)

通过一次 reduce_scatter 同时完成:

  1. 聚合Attention输出(求和)
  2. 分发结果(按列拆分)
  3. 聚合统计量sumsq_sum
  • GPU0发送Z0 的后半部分 + sum0 + sq_sum0
  • GPU1发送Z1 的前半部分 + sum1 + sq_sum1

通信后结果

  • GPU0获得

    • 聚合的Attention输出(前2列):
      attn_out0 = [[a00 + b00, a01 + b01], [a10 + b10, a11 + b11]]
    • 全局统计量(每个token):
      global_sum = [s00 + s01, s10 + s11]
      global_sq_sum = [q00 + q01, q10 + q11]
  • GPU1获得

    • 聚合的Attention输出(后2列):
      attn_out1 = [[a02 + b02, a03 + b03], [a12 + b12, a13 + b13]]
    • 全局统计量(同GPU0)

优化点:统计量聚合(标量)的通信开销极小,融合后不增加额外成本。


步骤4: 全局归一化

每个设备使用聚合后的统计量归一化本地结果:

# GPU0计算(以第一个token为例)
μ = global_sum[0] / H           # 全局均值
σ² = global_sq_sum[0]/H - μ²    # 全局方差
z0_normalized = (attn_out0[0] - μ) / sqrt(σ² + ε)  # 归一化# GPU1同理

最终得到:

  • GPU0: 归一化后的前2列 [[z00_norm, z01_norm], [z10_norm, z11_norm]]
  • GPU1: 归一化后的后2列 [[z02_norm, z03_norm], [z12_norm, z13_norm]]

通信效率对比

方案通信操作通信量
传统方案reduce_scatter + all_reduce(统计量)S*H + 2*S
融合方案单次reduce_scatterS*H(统计量内嵌)

其中 S*H 是Attention聚合的通信量,2*S 是统计量聚合的通信量。融合后节省 2*S 的通信开销。


关键优化总结

  1. 残差预加:在通信前将残差加到本地Attention输出,避免完整张量操作。
  2. 统计量内嵌:将LayerNorm的局部统计量计算与 reduce_scatter 融合,避免额外的 all_reduce
  3. 归一化本地化:用聚合后的全局统计量在本地完成归一化。

这种设计显著减少了通信次数和内存占用,是Megatron-LM等框架实现千亿级模型高效训练的核心技术之一。

http://www.xdnf.cn/news/12078.html

相关文章:

  • MySQL事务:从ACID特性到高并发优化的深度解析
  • day 44
  • K8S主机漏洞扫描时检测到kube-服务目标SSL证书已过期漏洞的一种永久性修复方法
  • 【论文写作】如何撰写基于模型拼接(A+B)的创新性论文
  • leetcode 二叉搜索树中第k小的元素 java
  • SiFli 567+emmc Standby休眠报错问题
  • 重装系统+驱动+磁盘分区
  • day19 leetcode-hot100-37(二叉树2)
  • 5.29-6.4解决问题归纳
  • 银行用户信誉等级
  • 前端面试宝典---vite原理解析
  • Numpy——结构化数组和Numpy文件
  • 【电赛培训课程】电子设计竞赛工程基础知识
  • 使用qt 定义全局钩子 捕获系统的键盘事件
  • 《人性的弱点》核心总结
  • AI基础认知
  • 【Python连接数据库基础 06】Pandas与SQL协同:解锁大规模数据处理新境界,让分析效率飙升10倍
  • 代理IP:6G标准化进程中的隐形推手
  • 如何在 React 中监听 div 的滚动事件
  • Pendulum:优雅处理 Python 中的日期与时间
  • vue3动态插入iframe,内容被取消的原因
  • pack 布局管理器
  • 第十三节:第三部分:集合框架:Map集合的遍历方式
  • 数码相片冲印规格参考表
  • Docker load 后镜像名称为空问题的解决方案
  • 国芯思辰ADE芯片成功替代ADS1296R,除颤仪核心部件实现自主可控
  • git删除本地分支和远程分支
  • 非对称加密
  • MuLogin浏览器如何使用Loongproxy?
  • 【AI系列】DPO 与 PPO 的比较与分析