【大模型推理学习】flashMLA (二)
FlashMLA 的 Stream-K 优化 是解决序列处理(尤其是变长序列)中负载不均衡和计算碎片问题的关键技术。其核心思想在于沿序列维度而非 Batch 维度进行任务划分,并根据可用 SM 资源动态分配任务。以下是其关键优化点详解:
-
核心问题:Batch 维划分的缺陷
- 传统 GEMM 优化通常沿 Batch 维划分任务(每个 SM 处理一个或多个完整序列)。
- 缺陷: 当序列长度差异巨大(变长序列常见)时:
- 负载不均衡: 处理短序列的 SM 很快完成工作而闲置;处理长序列的 SM 仍在忙碌。导致整体效率下降(“长尾效应”)。
- 计算碎片: 即使使用动态调度(如 Persistence Kernels),频繁的任务启动(多次 global waves)引入显著调度开销。
-
Stream-K 的核心优化:序列维切分 + SM 数匹配
- 任务切分维度转变: 不再将
batch_size
个序列视为batch_size
个独立任务。而是将整个 Batch 的所有序列沿着序列维度(K 或 N 维,取决于算子)拼接起来,形成一个“超级序列”。 - 均匀划分: 将这个“超级序列”的总计算量(例如,总 token 数)均匀地划分成
num_sm
个任务块(Tile)。num_sm
是 GPU 上可用的 SM 数量(减去其他并行维度如 Head 占用的 SM)。 - 任务分配: 每个任务块被分配给一个 独立的 CTA。
- CTA 数 = SM 数: 关键设置:启动的 CTA 总数 (
gridDim.x
) 被设定为num_sm
。
- 任务切分维度转变: 不再将
-
如何解决核心问题:
- 消除负载不均衡:
- 每个 CTA 处理的任务量是近似相等的(总计算量 / SM 数)。
- 短序列可能被多个 CTA 共同处理(如果它跨过了任务块边界)。
- 长序列必然被拆分成多个任务块,由多个 CTA(分布在多个 SM 上)协同完成。
- 结果: 所有 SM 的计算负载基本平衡,避免空闲等待。长序列的计算被自然地并行化到多个 SM。
- 消除计算碎片 & 减少调度开销:
- 全局 Wave 数 = 1: 因为 CTA 总数等于 SM 数,GPU 硬件调度器可以在一个全局 Wave(一次调度)内将所有 CTA 发射到所有 SM 上执行。
- 零调度开销: 没有后续 Wave 需要调度,完全消除了多 Wave 调度带来的开销。
- 无计算碎片: 所有 SM 在第一个(也是唯一一个)Wave 启动后就持续工作直到完成各自分配的任务块,没有中间空闲或小任务带来的碎片。
- 消除负载不均衡:
-
优势总结:
- 完美负载均衡: 尤其对变长序列输入效果显著,计算资源利用率最大化。
- 极低调度开销: 单 Wave 调度,消除了内核启动和 Wave 调度的主要瓶颈。
- 无计算碎片: SM 持续饱和工作。
- 隐式长序列并行: 天然地将长序列的计算分布到多个 SM,无需特殊处理。
- 资源利用高效: CTA 数精确匹配物理 SM 数,避免资源争抢或浪费。
-
应用场景:
- FlashMLA: 核心优化手段,用于处理注意力计算中的变长 K/V 序列。
- Marlin Kernel: 用于 4-bit 权重反量化 GEMM,同样受益于负载均衡和低调度开销。
- FA3 (Fused Attention 训练算子): 在训练阶段处理变长序列时应用 Stream-K 优化 GEMM 部分。
- 其他变长序列 GEMM: 任何需要处理大量长度不一序列的 GEMM 计算均可借鉴此思想。
简单比喻:
想象一个工厂(GPU)有 num_sm
条生产线(SM)。传统方法(Batch 划分)是把 batch_size
个不同长度的产品(序列)分别交给 batch_size
个小组(CTA),每组负责完整做一个产品。结果做短产品的小组早早下班,做长产品的小组累死累活,工厂效率低下。
Stream-K 的方法是:把所有产品拆解成零件(token),堆成一座零件山(超级序列)。然后把这堆零件均匀地分成 num_sm
份。工厂的 num_sm
条生产线(SM)同时开工,每条线只负责处理自己那一份零件。无论零件原来属于哪个产品(序列),每条线的工作量几乎相同。工厂一次启动(单 Wave)所有生产线,大家同时开始、同时结束(理想情况),没有等待,效率最高。
总结: FlashMLA 的 Stream-K 优化通过沿序列维度均匀切分计算负载,并严格将 CTA 数量设置为可用 SM 数量,实现了在变长序列处理上的近乎完美的负载均衡、单 Wave 零调度开销和无计算碎片,显著提升了 GPU 在 LLM 推理和训练中关键算子的计算效率。