GEMM inTriton (Split-K and Stream-K)
Triton是OpenAI的开源项目。官网https://openai.com/index/triton/。Github地址https://github.com/triton-lang/triton。自问世来,一直以来都受到业界关注,而且近年来热度似乎有了明显提升。可以看到将Triton用于LLM的例子越来越多。各种流行的LLM框架,如vLLM,SGLang和TRT-LLM中也都有了Triton的身影。PyTorch也对它进行了官方支持。在PyTorch中Triton可用于自定义算子的开发并方便地与torch.compile
集成(https://pytorch.org/tutorials/recipes/torch_compile_user_defined_triton_kernel_tutorial.html)
它主要解决的是在并行加速芯片上写高性能算子的问题。像CUDA这样的编程接口易学难精。写个能工作的实现和写出具有SOTA性能的kernel所需的专业时间和工程精力差别巨大。Triton的定位就是以极小的工程代价能达到手写算子约八成的性能。正如同Python与C++,或者C++与汇编的关系。将来可能大多算子会用Triton开发,只有在那些性能瓶颈的算子才会用CUDA去开发。
对于大多数使用者而言,更关心的是如何使用Triton。一个官方hello world可见:https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html。从中可以大概看到用Triton写一个kernel的基本范式与套路。Triton的DSL中有program_id
的概念,对应CUDA中的CTA,也就是thread block。使用中很多时候以block为单位,这样就可以尽量少地纠缠于warp, thread等更细节的概念。
本文以最常见的计算GEMM为例,看下用Triton是如何实现它以及它的几种变体(Split-K,Stream-K)的。
Classic GEMM
作为引子,首先看下经典的用tiling来做GEMM是如何在Triton中实现的。基本的写法可参见:https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html。Kernel的CTA的个数就是输出矩阵中的tile数量。也就是说,每个CTA计算一个输出矩阵中的tile。它需要循环多次进行累加。循环次数为k维上的block数。
与CUDA有所不同的是,像这里的offs_am
,offs_bn
等描述的都是一个range,即下标数组。如果BLOCK_SIZE_M
不能整除M
的话,余数部分会从0开始。这一部分是冗余计算,最后会用mask过滤掉。整体代码比较易懂,不需要过多解释。这里稍微绕一些的可能用于L2 optimization的CTA id重映射。
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
经过这个变换(M维度上的分块)后,按CTA id递增tile的顺序变成:
Split-K GEMM
源码可参考https://github.com/triton-lang/triton/blob/v2.1.0/python/triton/ops/matmul.py。首先是kernel启动部分:
# launch kernel
grid = lambda META: ( cdiv(M, META["BLOCK_M"]) * cdiv(N, META["BLOCK_N"]),META["SPLIT_K"],
)
_kernel[grid](
这里grid为二维,第一个维度为output矩阵的tile数,即M维的block数 x N维的block数,第二个维度为Split-K中的K,即K维上分成几个partition。这些partition由不同的CTA计算并累加。
接下来就是kernel的定义:
@autotune( configs=[ # basic configs for compute-bound matmuls Config( {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1},num_stages=3, num_warps=8, ), ...
}
@heuristics( { "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, }
)
@jit
def _kernel(...)
这里由@jit
修饰的就是kernel函数的定义了。Kernel函数的参数分成几类:
- 调用者给的,如输入输出,维度信息这些。
- 可tuning参数,如block大小。它们与性能相关。可参见:https://triton-lang.org/main/python-api/generated/triton.autotune.html。
- 基于heuristics得到的参数,基于预定义规则得到。如
EVEN_K
代表K维上的元素能否被CTA平分。可参见:https://triton-lang.org/main/python-api/generated/triton.heuristics.html。
上面这个kernel实现可分为三个部分:
- 根据
program_id
确定当前要处理的数据。
# matrix multiplication
pid = tl.program_id(0)
pid_z = tl.program_id(1)
grid_m = tl.cdiv(M, BLOCK_M)
grid_n = tl.cdiv(N, BLOCK_N)
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
# pointers
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
其中以下几行:
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
告诉编译器数据是align且连续的。这样编译器就可以做一些诸如vectorization的优化。
- 对于每个输出中的tile,在K维上做进行累加。
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): if EVEN_K: a = tl.load(A) b = tl.load(B) else: k_remaining = K - k * (BLOCK_K * SPLIT_K) _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) if AB_DTYPE is not None: a = a.to(AB_DTYPE) b = b.to(AB_DTYPE) if fp8_fast_accum: acc = tl.dot( a, b, acc, out_dtype=acc_dtype, input_precision=input_precision ) else: acc += tl.dot(a, b, out_dtype=acc_dtype, input_precision=input_precision) A += BLOCK_K * SPLIT_K * stride_ak B += BLOCK_K * SPLIT_K * stride_bk
acc = acc.to(C.dtype.element_ty)
- 写回结果。
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
# handles write-back with reduction-splitting
if SPLIT_K == 1: tl.store(C, acc, mask=mask)
else: tl.atomic_add(C, acc, mask=mask)
这里首先重新计算了输出矩阵中的坐标。其实前面已经算过了,但如果保留到这里就需要耗费register。
然后用mask保证不越界。最后判断如果SPLIT_K
为1,即退化为非SPLIT K的情况,那就直接存结果。否则意味着多个CTA共同计算一个output矩阵的tile,那就需要做累加。又由于这些写的线程属于不同CTA,可能并行执行,因此需要使用atomic add。
Stream-K GEMM
传统的做法是以problem为出发点,将problem size进行切割后,再将它们分到并行的计算单元上。但GPU上的并行计算处理器(SM)数量是固定的。这可能导致wave quantization问题(即子问题的数量不是并行处理器的整数),浪费计算资源。举个最简单的例子,一个任务可分为11个小任务(每个小工作需要1人/天),分给10个人干。但共需要2天才能全部完成,第二个会有9个人是无事可干的。而且这种问题随着GPU的更新,会越来越严重。因为计算单元更强,意味着需要更大的子问题才能“喂饱”它。那子问题的数量自然就会更少,这样就更容易出现wave quantization的现象。
那我们是否可以将子问题切得足够小(如果M, N维不够切就用Split-K),这样就能减少或避免wave quantization。但这样虽然SM利用率可能高了,但性能未必高。Block size过小可能会导致线程内IPL机会变小,另外计算密度变小。如果这样比较抽象的话,可以看一下具体的例子:https://pytorch.org/blog/accelerating-triton/#50-warp-stalling。
Split-K的限制在于它需要K维上的block在CTA间均分。那我们可不可以换个角度,从并行计算处理器为出发点,将子问题按SM数量来切分?这样所有的任务都可以在一个wave中完成,自然也不存在wave quantization的问题。这就是Stream-K的思想。还是用上面的例子,一个任务分为11个小的子任务,将11个子任务按人数分成10份,每个人完成1.1份。Stream-K与Split-K相比,SM利用率更高。而且同步归约次数也更少(不多于SM个数)。
Stream-K算法可参见2023年的论文《Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU》的Algorithm 5
。但单纯使用Stream-K可能导致tile-processing skew问题。因此实际使用时会采用称为"two-tile Stream-K + data parallel"的混合调度方法。它的基本思想是先用Stream-K将那些除不尽的可能导致wave quantization的部分在一个wave中处理掉。这样剩下的部分是能整除的,就可以用传统的高效的DP(data parallel)方法来计算了。便于理解的形象的图可参见论文中的Figure 3 (c)
。
接下来学习如何用Triton来实现Stream-K:https://github.com/triton-lang/triton/issues/1393。代码大体可分三步:首先根据problem计算workload在Stream-K和DP之间怎么划分。然后分别起两个kernel用Stream-K和DP来计算相应的部分。
- 准备阶段
M, K = a.shape
_, N = b.shape
# accumulator types
ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
# compute grid (work to do per SM on the first wave)
total_blocks_M = triton.cdiv(M, BLK_M)
total_blocks_N = triton.cdiv(N, BLK_N)
iters_per_tile = triton.cdiv(K, BLK_K)
GROUP_M = 8 # 0 to disable swizzling
total_tiles = total_blocks_M * total_blocks_Nif total_programs_streamk > 0: # Stream-K# last wave may occupy less than total_programs_streamk SMstotal_tiles_streamk = total_tiles % total_programs_streamk# for two-tile Stream-K + data-parallel from original paperif two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk:total_tiles_streamk += total_programs_streamk# remaining tiles are computed using classical blockingtotal_blocking_tiles = total_tiles - total_tiles_streamktotal_iters_streamk = total_tiles_streamk * iters_per_tile# iterations related to full wavestotal_full_tiles_streamk = total_iters_streamk // total_programs_streamk# iterations related to last (partial) wavetotal_partial_tiles_streamk = total_iters_streamk % total_programs_streamk...
这里total_programs_streams
是SM数量(同时也是CTA数量)。total_blocks_M
,total_blocks_N
与iters_per_tile
分别是M, N, K三个维度上的block数。total_tiles
是输出矩阵上的tile数量。将这个tile数量除以SM数量,余数部分交给stream k处理。这个数量为total_tiles_streamk
。另外根据paper中的two tiles策略解决SM间的workload imbalance问题,对于这部分每个SM还要再多算一个tile。如共21个tile,4个SM。余数为1,每个SM再多算一个tile,则stream k中,是4个SM计算5个tile。那剩下的tile(数量为total_blocking_tiles
)是由传统DP方法来算。total_full_tiles_streamk
和total_partial_tiles_streamk
分别是Stream-K中每个SM处理的iter次数的最小值,以及余下的iter次数(意味着有些SM需要平摊这些余下的部分)。下面是一个小规模(5个tile,每个tile在K维上2个block)的切分示意图:
可以看到,除不尽的余数部分被分到了前几个SM。
- Stream-K
对应的kernel函数为first_wave
。这里的grid size为SM数量。也就是说每个CTA对应一个SM。SM是可以并行执行的。
# allocates output
c = torch.empty((M, N), device=device, dtype=a.dtype)
# allocates locks to sync work accross SMs
locks = torch.zeros((total_tiles_streamk,), device=device, dtype=torch.int32)
k1 = first_wave[(total_programs_streamk,)](...)
下面看下该kernel的实现:
@triton.jit()
def first_wave( A, B, C, M, N, K, locks,stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, total_full_tiles_streamk, total_partial_tiles_streamk, iters_per_tile, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ACC_TYPE: tl.constexpr,GROUP_M: tl.constexpr,
): pid = tl.program_id(0) start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk) last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum(pid + 1, total_partial_tiles_streamk)while start_iter < last_iter:end_iter = tl.minimum(start_iter + (iters_per_tile - start_iter % iters_per_tile), last_iter)mac_loop(A, B, C,M, N, K, locks,stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,iters_per_tile, start_iter, end_iter, BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE, GROUP_M, )start_iter = end_iter
它对应论文的Algorithm 5
。这里有点绕的地方在于一个CTA处理可能跨了两个tile,也就是说包含了两个tile的iteration。先以简单的场景(2个CTA处理3个tile,每个tile在K维上分4个block)过一下论文中的算法。
回到代码来。代码中的start_iter
和last_iter
对应论文中iter
和iter_end
,代表当前CTA要处理的iter范围。这里在计算当前CTA的iter范围(即start_iter
与last_iter
)时还要考虑无法整除的情况。
由于这些iter可能跨tile,因此要调多次mac loop完成。代码中的start_iter
和end_iter
对应论文中的 local_iter
和local_iter_end
,代表当前这次mac loop所处理的iter范围。Mac loop的主体部分就是对于给定范围在k维上做reduction。代码实现中将数据的写回部分也放到mac_loop
函数中了,所以与论文中Algorithm 3
中的MacLoop
相比,看起来要复杂一些。
结果写回部分的复杂之处在于一个tile中的iter现在可能由不同的CTA完成了。因此得由一个CTA写,其它的做原子累加。假设某个tile中的iter会由2个CTA(标为CTA0和CTA1)共同完成。
对应地,代码中,CTA 1的mac_loop
中由于条件end_iter % iters_per_tile == 0
成立,会将当前的partial accumulator写入到结果矩阵C
中。然后通过tl.atomic_xchg
通知其它CTA。CTA 0在mac_loop
中会先等待CTA 1完成的信号,然后通过tl.atomic_add
将partial accumulator以原子操作的方式加到结果矩阵。示意图如下:
- DP
这部分起第二个kernel,采用经典的分块策略计算剩余部分。该策略前面已经讲过了,这里不再累述。
k2 = full_tiles[(total_blocking_tiles,)](...)