当前位置: 首页 > web >正文

GEMM inTriton (Split-K and Stream-K)

Triton是OpenAI的开源项目。官网https://openai.com/index/triton/。Github地址https://github.com/triton-lang/triton。自问世来,一直以来都受到业界关注,而且近年来热度似乎有了明显提升。可以看到将Triton用于LLM的例子越来越多。各种流行的LLM框架,如vLLM,SGLang和TRT-LLM中也都有了Triton的身影。PyTorch也对它进行了官方支持。在PyTorch中Triton可用于自定义算子的开发并方便地与torch.compile集成(https://pytorch.org/tutorials/recipes/torch_compile_user_defined_triton_kernel_tutorial.html)

它主要解决的是在并行加速芯片上写高性能算子的问题。像CUDA这样的编程接口易学难精。写个能工作的实现和写出具有SOTA性能的kernel所需的专业时间和工程精力差别巨大。Triton的定位就是以极小的工程代价能达到手写算子约八成的性能。正如同Python与C++,或者C++与汇编的关系。将来可能大多算子会用Triton开发,只有在那些性能瓶颈的算子才会用CUDA去开发。

对于大多数使用者而言,更关心的是如何使用Triton。一个官方hello world可见:https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html。从中可以大概看到用Triton写一个kernel的基本范式与套路。Triton的DSL中有program_id的概念,对应CUDA中的CTA,也就是thread block。使用中很多时候以block为单位,这样就可以尽量少地纠缠于warp, thread等更细节的概念。

本文以最常见的计算GEMM为例,看下用Triton是如何实现它以及它的几种变体(Split-K,Stream-K)的。

Classic GEMM

作为引子,首先看下经典的用tiling来做GEMM是如何在Triton中实现的。基本的写法可参见:https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html。Kernel的CTA的个数就是输出矩阵中的tile数量。也就是说,每个CTA计算一个输出矩阵中的tile。它需要循环多次进行累加。循环次数为k维上的block数。

与CUDA有所不同的是,像这里的offs_amoffs_bn等描述的都是一个range,即下标数组。如果BLOCK_SIZE_M不能整除M的话,余数部分会从0开始。这一部分是冗余计算,最后会用mask过滤掉。整体代码比较易懂,不需要过多解释。这里稍微绕一些的可能用于L2 optimization的CTA id重映射。

pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

经过这个变换(M维度上的分块)后,按CTA id递增tile的顺序变成:
在这里插入图片描述

Split-K GEMM

源码可参考https://github.com/triton-lang/triton/blob/v2.1.0/python/triton/ops/matmul.py。首先是kernel启动部分:

# launch kernel
grid = lambda META: ( cdiv(M, META["BLOCK_M"]) * cdiv(N, META["BLOCK_N"]),META["SPLIT_K"],
)
_kernel[grid](

这里grid为二维,第一个维度为output矩阵的tile数,即M维的block数 x N维的block数,第二个维度为Split-K中的K,即K维上分成几个partition。这些partition由不同的CTA计算并累加。

接下来就是kernel的定义:

@autotune(                                                                configs=[                                                             # basic configs for compute-bound matmuls                         Config(                                                           {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1},num_stages=3,                                                 num_warps=8,                                                  ),                                                                ...
}
@heuristics(                                                                         {                                                                                "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, }                                                                                
)                                                                                    
@jit                                                                                 
def _kernel(...)                                                                       

这里由@jit修饰的就是kernel函数的定义了。Kernel函数的参数分成几类:

  • 调用者给的,如输入输出,维度信息这些。
  • 可tuning参数,如block大小。它们与性能相关。可参见:https://triton-lang.org/main/python-api/generated/triton.autotune.html。
  • 基于heuristics得到的参数,基于预定义规则得到。如EVEN_K代表K维上的元素能否被CTA平分。可参见:https://triton-lang.org/main/python-api/generated/triton.heuristics.html。

上面这个kernel实现可分为三个部分:

  1. 根据program_id确定当前要处理的数据。
# matrix multiplication                                           
pid = tl.program_id(0)                                            
pid_z = tl.program_id(1)                                          
grid_m = tl.cdiv(M, BLOCK_M)                                      
grid_n = tl.cdiv(N, BLOCK_N)                                      
# re-order program ID for better L2 performance                   
width = GROUP_M * grid_n                                          
group_id = pid // width                                           
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)            
pid_m = group_id * GROUP_M + (pid % group_size)                   
pid_n = (pid % width) // (group_size)                             
# do matrix multiplication                                        
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)                      
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)                      
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) 
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) 
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)                      
# pointers                                                        
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)      
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)      

其中以下几行:

rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)

告诉编译器数据是align且连续的。这样编译器就可以做一些诸如vectorization的优化。

  1. 对于每个输出中的tile,在K维上做进行累加。
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)                                
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):                                  if EVEN_K:                                                                     a = tl.load(A)                                                             b = tl.load(B)                                                             else:                                                                          k_remaining = K - k * (BLOCK_K * SPLIT_K)                                  _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)                            a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)                   b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)                   if AB_DTYPE is not None:                                                       a = a.to(AB_DTYPE)                                                         b = b.to(AB_DTYPE)                                                         if fp8_fast_accum:                                                             acc = tl.dot(                                                              a, b, acc, out_dtype=acc_dtype, input_precision=input_precision        )                                                                          else:                                                                          acc += tl.dot(a, b, out_dtype=acc_dtype, input_precision=input_precision)  A += BLOCK_K * SPLIT_K * stride_ak                                             B += BLOCK_K * SPLIT_K * stride_bk                                             
acc = acc.to(C.dtype.element_ty)                                                   
  1. 写回结果。
# rematerialize rm and rn to save registers                  
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)                 
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)                 
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)  
mask = (rm < M)[:, None] & (rn < N)[None, :]                 
# handles write-back with reduction-splitting                
if SPLIT_K == 1:                                             tl.store(C, acc, mask=mask)                              
else:                                                        tl.atomic_add(C, acc, mask=mask)                         

这里首先重新计算了输出矩阵中的坐标。其实前面已经算过了,但如果保留到这里就需要耗费register。

然后用mask保证不越界。最后判断如果SPLIT_K为1,即退化为非SPLIT K的情况,那就直接存结果。否则意味着多个CTA共同计算一个output矩阵的tile,那就需要做累加。又由于这些写的线程属于不同CTA,可能并行执行,因此需要使用atomic add。

Stream-K GEMM

传统的做法是以problem为出发点,将problem size进行切割后,再将它们分到并行的计算单元上。但GPU上的并行计算处理器(SM)数量是固定的。这可能导致wave quantization问题(即子问题的数量不是并行处理器的整数),浪费计算资源。举个最简单的例子,一个任务可分为11个小任务(每个小工作需要1人/天),分给10个人干。但共需要2天才能全部完成,第二个会有9个人是无事可干的。而且这种问题随着GPU的更新,会越来越严重。因为计算单元更强,意味着需要更大的子问题才能“喂饱”它。那子问题的数量自然就会更少,这样就更容易出现wave quantization的现象。

那我们是否可以将子问题切得足够小(如果M, N维不够切就用Split-K),这样就能减少或避免wave quantization。但这样虽然SM利用率可能高了,但性能未必高。Block size过小可能会导致线程内IPL机会变小,另外计算密度变小。如果这样比较抽象的话,可以看一下具体的例子:https://pytorch.org/blog/accelerating-triton/#50-warp-stalling。

Split-K的限制在于它需要K维上的block在CTA间均分。那我们可不可以换个角度,从并行计算处理器为出发点,将子问题按SM数量来切分?这样所有的任务都可以在一个wave中完成,自然也不存在wave quantization的问题。这就是Stream-K的思想。还是用上面的例子,一个任务分为11个小的子任务,将11个子任务按人数分成10份,每个人完成1.1份。Stream-K与Split-K相比,SM利用率更高。而且同步归约次数也更少(不多于SM个数)。

Stream-K算法可参见2023年的论文《Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU》的Algorithm 5。但单纯使用Stream-K可能导致tile-processing skew问题。因此实际使用时会采用称为"two-tile Stream-K + data parallel"的混合调度方法。它的基本思想是先用Stream-K将那些除不尽的可能导致wave quantization的部分在一个wave中处理掉。这样剩下的部分是能整除的,就可以用传统的高效的DP(data parallel)方法来计算了。便于理解的形象的图可参见论文中的Figure 3 (c)

接下来学习如何用Triton来实现Stream-K:https://github.com/triton-lang/triton/issues/1393。代码大体可分三步:首先根据problem计算workload在Stream-K和DP之间怎么划分。然后分别起两个kernel用Stream-K和DP来计算相应的部分。

  1. 准备阶段
M, K = a.shape
_, N = b.shape
# accumulator types
ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
# compute grid (work to do per SM on the first wave)
total_blocks_M = triton.cdiv(M, BLK_M)
total_blocks_N = triton.cdiv(N, BLK_N)
iters_per_tile = triton.cdiv(K, BLK_K)
GROUP_M = 8  # 0 to disable swizzling
total_tiles = total_blocks_M * total_blocks_Nif total_programs_streamk > 0:  # Stream-K# last wave may occupy less than total_programs_streamk SMstotal_tiles_streamk = total_tiles % total_programs_streamk# for two-tile Stream-K + data-parallel from original paperif two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk:total_tiles_streamk += total_programs_streamk# remaining tiles are computed using classical blockingtotal_blocking_tiles = total_tiles - total_tiles_streamktotal_iters_streamk = total_tiles_streamk * iters_per_tile# iterations related to full wavestotal_full_tiles_streamk = total_iters_streamk // total_programs_streamk# iterations related to last (partial) wavetotal_partial_tiles_streamk = total_iters_streamk % total_programs_streamk...

这里total_programs_streams是SM数量(同时也是CTA数量)。total_blocks_Mtotal_blocks_Niters_per_tile分别是M, N, K三个维度上的block数。total_tiles是输出矩阵上的tile数量。将这个tile数量除以SM数量,余数部分交给stream k处理。这个数量为total_tiles_streamk。另外根据paper中的two tiles策略解决SM间的workload imbalance问题,对于这部分每个SM还要再多算一个tile。如共21个tile,4个SM。余数为1,每个SM再多算一个tile,则stream k中,是4个SM计算5个tile。那剩下的tile(数量为total_blocking_tiles)是由传统DP方法来算。total_full_tiles_streamktotal_partial_tiles_streamk分别是Stream-K中每个SM处理的iter次数的最小值,以及余下的iter次数(意味着有些SM需要平摊这些余下的部分)。下面是一个小规模(5个tile,每个tile在K维上2个block)的切分示意图:
在这里插入图片描述
可以看到,除不尽的余数部分被分到了前几个SM。

  1. Stream-K

对应的kernel函数为first_wave。这里的grid size为SM数量。也就是说每个CTA对应一个SM。SM是可以并行执行的。

# allocates output
c = torch.empty((M, N), device=device, dtype=a.dtype)
# allocates locks to sync work accross SMs
locks = torch.zeros((total_tiles_streamk,), device=device, dtype=torch.int32)
k1 = first_wave[(total_programs_streamk,)](...)

下面看下该kernel的实现:

@triton.jit()                                                                                    
def first_wave(                                                                                  A, B, C,                                                                                 M, N, K,                                                                                 locks,stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,                        total_full_tiles_streamk, total_partial_tiles_streamk, iters_per_tile,                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ACC_TYPE: tl.constexpr,GROUP_M: tl.constexpr,
):      pid = tl.program_id(0)                                                                       start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk)   last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum(pid + 1, total_partial_tiles_streamk)while start_iter < last_iter:end_iter = tl.minimum(start_iter + (iters_per_tile - start_iter % iters_per_tile), last_iter)mac_loop(A, B, C,M, N, K,                                                                        locks,stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,iters_per_tile,                                                                 start_iter, end_iter,                                                           BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,                                            GROUP_M,                                                                        )start_iter = end_iter

它对应论文的Algorithm 5。这里有点绕的地方在于一个CTA处理可能跨了两个tile,也就是说包含了两个tile的iteration。先以简单的场景(2个CTA处理3个tile,每个tile在K维上分4个block)过一下论文中的算法。
在这里插入图片描述

回到代码来。代码中的start_iterlast_iter对应论文中iteriter_end,代表当前CTA要处理的iter范围。这里在计算当前CTA的iter范围(即start_iterlast_iter)时还要考虑无法整除的情况。

由于这些iter可能跨tile,因此要调多次mac loop完成。代码中的start_iterend_iter对应论文中的 local_iterlocal_iter_end,代表当前这次mac loop所处理的iter范围。Mac loop的主体部分就是对于给定范围在k维上做reduction。代码实现中将数据的写回部分也放到mac_loop函数中了,所以与论文中Algorithm 3中的MacLoop相比,看起来要复杂一些。

结果写回部分的复杂之处在于一个tile中的iter现在可能由不同的CTA完成了。因此得由一个CTA写,其它的做原子累加。假设某个tile中的iter会由2个CTA(标为CTA0和CTA1)共同完成。

对应地,代码中,CTA 1的mac_loop中由于条件end_iter % iters_per_tile == 0成立,会将当前的partial accumulator写入到结果矩阵C中。然后通过tl.atomic_xchg通知其它CTA。CTA 0在mac_loop中会先等待CTA 1完成的信号,然后通过tl.atomic_add将partial accumulator以原子操作的方式加到结果矩阵。示意图如下:
在这里插入图片描述

  1. DP
    这部分起第二个kernel,采用经典的分块策略计算剩余部分。该策略前面已经讲过了,这里不再累述。
k2 = full_tiles[(total_blocking_tiles,)](...)
http://www.xdnf.cn/news/4084.html

相关文章:

  • 经典的 Masked + Self-supervised learning 的模型方法
  • 学习路线(视觉)
  • Deep-Live-Cam-实时换脸开源部署和使用
  • sqli-labs靶场11-17关(POST型)
  • 小白学习java第16天(下):javaweb
  • 【C/C++】inline关键词
  • 第六章:6.1 ESP32教学:多任务处理与FreeRTOS实战
  • 谷歌SMR测试环境搭建
  • Spring 框架中 @Configuration 注解详解
  • Springboot循环依赖
  • FOC算法开环控制基础
  • Java开发者面试实录:微服务架构与Spring Cloud的应用
  • 学习黑客Nmap 原理
  • 什么是外联模板(extern template)?
  • 【阿里云大模型高级工程师ACP学习笔记】2.9 大模型应用生产实践 (下篇)
  • C++竞赛指南
  • 搜索速度迅猛,能在0.001秒内迅速找到文件,但遗憾的是,该软件已经停止更新
  • 前端- ElementPlus入门
  • yolov11 epoch100轮 训练笔记5 kaggle comet
  • Android学习总结之GetX库篇(优缺点)
  • 进程的程序替换——exec系列函数的使用
  • 效整理文件信息!一键生成文件夹目录的工具
  • 8.渐入佳境 -- 域名及网络地址
  • Unity:Surface Effector 2D(表面效应器 2D)
  • OSE2.【Linux】练习:查找项目的main函数入口
  • 开元类双端互动组件部署实战全流程教程(第3部分:UI资源加载机制与界面逻辑全面解析
  • 事务隔离(MySQL)
  • FTP(文件传输协议)
  • 15.日志分析入门
  • LeetCode算法题 (反转链表)Day17!!!C/C++