
Triton learning notes 7: GEMM

These are my earlier notes in this series:

  1. triton puzzles part1
  2. triton puzzles part2
  3. triton puzzles part3
  4. triton tutorials part1
  5. triton tutorials: part2
  6. triton tutorials: part3

This is the last post of the series, covering the GEMM-related Triton tutorials.
For background on GEMM itself, this article is very detailed and concrete: https://zhuanlan.zhihu.com/p/703256080

Group GEMM

from typing import Optional
import torch

import triton
import triton.language as tl

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


def supports_tma():
    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9


def num_sms():
    if is_cuda():
        return torch.cuda.get_device_properties("cuda").multi_processor_count
    return 148


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'NUM_SM': 84}),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'NUM_SM': 128}),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'NUM_SM': 84}),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'NUM_SM': 128}),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'NUM_SM': num_sms()}),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'NUM_SM': num_sms()}),
    ],
    key=['group_size'],
)
@triton.jit
def grouped_matmul_kernel(
    # device tensor of matrices pointers
    group_a_ptrs,
    group_b_ptrs,
    group_c_ptrs,
    # device tensor of gemm sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <M, N, K> of each gemm
    group_gemm_sizes,
    # device tensor of leading dimension sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <lda, ldb, ldc> of each gemm
    g_lds,
    # number of gemms
    group_size,
    # number of virtual SM
    NUM_SM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    tile_idx = tl.program_id(0)
    last_problem_end = 0
    for g in range(group_size):
        # get the gemm size of the current problem
        gm = tl.load(group_gemm_sizes + g * 3)
        gn = tl.load(group_gemm_sizes + g * 3 + 1)
        gk = tl.load(group_gemm_sizes + g * 3 + 2)
        num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
        num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
        num_tiles = num_m_tiles * num_n_tiles
        # iterate through the tiles in the current gemm problem
        while (tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles):
            # pick up a tile from the current gemm problem
            k = gk
            lda = tl.load(g_lds + g * 3)
            ldb = tl.load(g_lds + g * 3 + 1)
            ldc = tl.load(g_lds + g * 3 + 2)
            a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16))
            b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16))
            c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16))
            # figure out tile coordinates
            tile_idx_in_gemm = tile_idx - last_problem_end
            tile_m_idx = tile_idx_in_gemm // num_n_tiles
            tile_n_idx = tile_idx_in_gemm % num_n_tiles
            # do regular gemm here
            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            offs_k = tl.arange(0, BLOCK_SIZE_K)
            a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :]
            b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :]
            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
            for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
                # hint to Triton compiler to do proper loop pipelining
                tl.multiple_of(a_ptrs, [16, 16])
                tl.multiple_of(b_ptrs, [16, 16])
                # assume full tile for now
                a = tl.load(a_ptrs)
                b = tl.load(b_ptrs)
                accumulator += tl.dot(a, b)
                a_ptrs += BLOCK_SIZE_K
                b_ptrs += BLOCK_SIZE_K * ldb
            c = accumulator.to(tl.float16)
            offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :]
            # assumes full tile for now
            tl.store(c_ptrs, c)
            # go to the next tile by advancing NUM_SM
            tile_idx += NUM_SM
        # get ready to go to the next gemm problem
        last_problem_end = last_problem_end + num_tiles


def group_gemm_fn(group_A, group_B):
    assert len(group_A) == len(group_B)
    group_size = len(group_A)

    A_addrs = []
    B_addrs = []
    C_addrs = []
    g_sizes = []
    g_lds = []
    group_C = []
    for i in range(group_size):
        A = group_A[i]
        B = group_B[i]
        assert A.shape[1] == B.shape[0]
        M, K = A.shape
        K, N = B.shape
        C = torch.empty((M, N), device=DEVICE, dtype=A.dtype)
        group_C.append(C)
        A_addrs.append(A.data_ptr())
        B_addrs.append(B.data_ptr())
        C_addrs.append(C.data_ptr())
        g_sizes += [M, N, K]
        g_lds += [A.stride(0), B.stride(0), C.stride(0)]

    # note these are device tensors
    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)
    # we use a fixed number of CTA, and it's auto-tunable
    grid = lambda META: (META['NUM_SM'], )
    grouped_matmul_kernel[grid](
        d_a_ptrs,
        d_b_ptrs,
        d_c_ptrs,
        d_g_sizes,
        d_g_lds,
        group_size,
    )

    return group_C


tma_configs = [
    triton.Config({'BLOCK_SIZE_M': BM, 'BLOCK_SIZE_N': BN, 'BLOCK_SIZE_K': BK}, num_stages=s, num_warps=w)
    for BM in [128]
    for BN in [128, 256]
    for BK in [64, 128]
    for s in ([3, 4])
    for w in [4, 8]
]


@triton.autotune(
    tma_configs,
    key=['group_a_ptrs', 'group_b_ptrs', 'gropup_c_ptrs', 'group_size'],
)
@triton.jit
def grouped_matmul_tma_kernel(
    # device tensor of matrices pointers
    group_a_ptrs,
    group_b_ptrs,
    group_c_ptrs,
    # device tensor of gemm sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <M, N, K> of each gemm
    group_gemm_sizes,
    # device tensor of leading dimension sizes. its shape is [group_size, 3]
    # dim 0 is group_size, dim 1 is the values of <lda, ldb, ldc> of each gemm
    g_lds,
    # number of gemms
    group_size,
    # number of virtual SM
    NUM_SM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    # is the output FP8 or FP16
    FP8: tl.constexpr,
):
    dtype = tl.float8e4nv if FP8 else tl.float16
    tile_idx = tl.program_id(0)
    last_problem_end = 0
    for g in range(group_size):
        # get the gemm size of the current problem
        gm = tl.load(group_gemm_sizes + g * 3)
        gn = tl.load(group_gemm_sizes + g * 3 + 1)
        gk = tl.load(group_gemm_sizes + g * 3 + 2)
        num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
        num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
        num_tiles = num_m_tiles * num_n_tiles
        if tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles:
            # pick up a tile from the current gemm problem
            lda = tl.load(g_lds + g * 3)
            ldb = tl.load(g_lds + g * 3 + 1)
            ldc = tl.load(g_lds + g * 3 + 2)
            a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(dtype))
            b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(dtype))
            c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(dtype))
            a_desc = tl.make_tensor_descriptor(
                a_ptr,
                shape=[gm, gk],
                strides=[lda, 1],
                block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
            )
            b_desc = tl.make_tensor_descriptor(
                b_ptr,
                shape=[gn, gk],
                strides=[ldb, 1],
                block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
            )
            c_desc = tl.make_tensor_descriptor(
                c_ptr,
                shape=[gm, gn],
                strides=[ldc, 1],
                block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
            )
            # iterate through the tiles in the current gemm problem
            while (tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles):
                k = gk
                # figure out tile coordinates
                tile_idx_in_gemm = tile_idx - last_problem_end
                tile_m_idx = tile_idx_in_gemm // num_n_tiles
                tile_n_idx = tile_idx_in_gemm % num_n_tiles
                # do regular gemm here
                offs_am = tile_m_idx * BLOCK_SIZE_M
                offs_bn = tile_n_idx * BLOCK_SIZE_N
                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
                for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
                    a = a_desc.load([offs_am, kk * BLOCK_SIZE_K])
                    b = b_desc.load([offs_bn, kk * BLOCK_SIZE_K])
                    accumulator += tl.dot(a, b.T)
                offs_cm = tile_m_idx * BLOCK_SIZE_M
                offs_cn = tile_n_idx * BLOCK_SIZE_N
                c = accumulator.to(dtype)
                c_desc.store([offs_cm, offs_cn], c)
                # go to the next tile by advancing NUM_SM
                tile_idx += NUM_SM
        # get ready to go to the next gemm problem
        last_problem_end = last_problem_end + num_tiles


def group_gemm_tma_fn(group_A, group_B):
    assert supports_tma()
    assert len(group_A) == len(group_B)
    group_size = len(group_A)

    A_addrs = []
    B_addrs = []
    C_addrs = []
    g_sizes = []
    g_lds = []
    group_C = []
    for i in range(group_size):
        A = group_A[i]
        B = group_B[i]
        assert A.shape[1] == B.shape[1]
        M, K = A.shape
        N, K = B.shape
        C = torch.empty((M, N), device=DEVICE, dtype=A.dtype)
        group_C.append(C)
        A_addrs.append(A.data_ptr())
        B_addrs.append(B.data_ptr())
        C_addrs.append(C.data_ptr())
        g_sizes += [M, N, K]
        g_lds += [A.stride(0), B.stride(0), C.stride(0)]

    # note these are device tensors
    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)
    # we use a fixed number of CTA, and it's auto-tunable

    # TMA descriptors require a global memory allocation
    def alloc_fn(size: int, alignment: int, stream: Optional[int]):
        return torch.empty(size, device="cuda", dtype=torch.int8)

    triton.set_allocator(alloc_fn)
    grid = lambda META: (META['NUM_SM'], )
    grouped_matmul_tma_kernel[grid](d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size,
                                    FP8=torch.float8_e4m3fn == group_A[0].dtype, NUM_SM=num_sms())
    return group_C


group_m = [1024, 512, 256, 128]
group_n = [1024, 512, 256, 128]
group_k = [1024, 512, 256, 128]
group_A = []
group_B = []
group_B_T = []
assert len(group_m) == len(group_n)
assert len(group_n) == len(group_k)
group_size = len(group_m)
for i in range(group_size):
    M = group_m[i]
    N = group_n[i]
    K = group_k[i]
    A = torch.rand((M, K), device=DEVICE, dtype=torch.float16)
    B = torch.rand((K, N), device=DEVICE, dtype=torch.float16)
    B_T = B.T.contiguous()
    group_A.append(A)
    group_B.append(B)
    group_B_T.append(B_T)

tri_out = group_gemm_fn(group_A, group_B)
ref_out = [torch.matmul(a, b) for a, b in zip(group_A, group_B)]
for i in range(group_size):
    assert torch.allclose(ref_out[i], tri_out[i], atol=1e-2, rtol=1e-2)

if supports_tma():
    tri_tma_out = group_gemm_tma_fn(group_A, group_B_T)
    for i in range(group_size):
        assert torch.allclose(ref_out[i], tri_tma_out[i], atol=1e-2, rtol=1e-2)


# only launch the kernel, no tensor preparation here to remove all overhead
def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size):
    grid = lambda META: (META['NUM_SM'], )
    grouped_matmul_kernel[grid](a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size)


def triton_tma_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size, dtype):
    grid = lambda META: (META['NUM_SM'], )
    grouped_matmul_tma_kernel[grid](a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size,
                                    FP8=torch.float8_e4m3fn == dtype, NUM_SM=num_sms())


def torch_perf_fn(group_A, group_B):
    for a, b in zip(group_A, group_B):
        torch.matmul(a, b)


@triton.testing.perf_report(
    triton.testing.Benchmark(
        # argument names to use as an x-axis for the plot
        x_names=['N'],
        x_vals=[2**i for i in range(7, 11)],  # different possible values for `x_name`
        line_arg='provider',
        # argument name whose value corresponds to a different line in the plot
        # possible values for `line_arg`
        line_vals=['cublas', 'triton'] + (['triton-tma'] if supports_tma() else []),
        # label name for the lines
        line_names=["cuBLAS", "Triton"] + (['Triton + TMA'] if supports_tma() else []),
        # line styles
        styles=[('green', '-'), ('blue', '-')] + ([('red', '-')] if supports_tma() else []),
        ylabel="runtime(ms)",  # label name for the y-axis
        plot_name="group-gemm-performance",
        # name for the plot. Used also as a file name for saving the plot.
        args={},
    ))
def benchmark_square_matrices(N, provider):
    group_size = 4
    group_A = []
    group_B = []
    group_B_T = []
    A_addrs = []
    B_addrs = []
    B_T_addrs = []
    C_addrs = []
    g_sizes = []
    g_lds = []
    group_C = []
    for i in range(group_size):
        A = torch.rand((N, N), device=DEVICE, dtype=torch.float16)
        B = torch.rand((N, N), device=DEVICE, dtype=torch.float16)
        C = torch.empty((N, N), device=DEVICE, dtype=torch.float16)
        B_T = B.T.contiguous()
        group_A.append(A)
        group_B.append(B)
        group_B_T.append(B_T)
        group_C.append(C)
        A_addrs.append(A.data_ptr())
        B_addrs.append(B.data_ptr())
        B_T_addrs.append(B_T.data_ptr())
        C_addrs.append(C.data_ptr())
        g_sizes += [N, N, N]
        g_lds += [N, N, N]

    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
    d_b_t_ptrs = torch.tensor(B_T_addrs, device=DEVICE)
    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)

    quantiles = [0.5, 0.2, 0.8]
    if provider == 'cublas':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_perf_fn(group_A, group_B), quantiles=quantiles)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: triton_perf_fn(d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size),
            quantiles=quantiles)
    if provider == 'triton-tma':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: triton_tma_perf_fn(d_a_ptrs, d_b_t_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size,
                                       dtype=torch.float16), quantiles=quantiles)
    return ms, max_ms, min_ms


@triton.testing.perf_report(
    triton.testing.Benchmark(
        # argument names to use as an x-axis for the plot
        x_names=['M'],
        x_vals=[2**i for i in range(7, 11)],  # different possible values for `x_name`
        line_arg='provider',
        # argument name whose value corresponds to a different line in the plot
        # possible values for `line_arg`
        line_vals=['cublas', 'triton'] + (['triton-tma'] if supports_tma() else []),
        # label name for the lines
        line_names=["cuBLAS", "Triton"] + (['Triton + TMA'] if supports_tma() else []),
        # line styles
        styles=[('green', '-'), ('blue', '-')] + ([('red', '-')] if supports_tma() else []),
        ylabel="runtime(ms)",  # label name for the y-axis
        plot_name="group-gemm-performance-m-8192-k-8192",
        # name for the plot. Used also as a file name for saving the plot.
        args={},
    ))
def benchmark_batches(M, provider):
    N = 8192
    K = 8192
    group_size = 4
    group_A = []
    group_B = []
    group_B_T = []
    A_addrs = []
    B_addrs = []
    B_T_addrs = []
    C_addrs = []
    g_sizes = []
    g_lds = []
    g_T_lds = []
    group_C = []
    for i in range(group_size):
        A = torch.rand((M, K), device=DEVICE, dtype=torch.float16)
        B = torch.rand((K, N), device=DEVICE, dtype=torch.float16)
        C = torch.empty((M, N), device=DEVICE, dtype=torch.float16)
        B_T = B.T.contiguous()
        group_A.append(A)
        group_B.append(B)
        group_B_T.append(B_T)
        group_C.append(C)
        A_addrs.append(A.data_ptr())
        B_addrs.append(B.data_ptr())
        B_T_addrs.append(B_T.data_ptr())
        C_addrs.append(C.data_ptr())
        g_sizes += [M, N, K]
        g_lds += [A.stride(0), B.stride(0), C.stride(0)]
        g_T_lds += [A.stride(0), B_T.stride(0), C.stride(0)]

    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
    d_b_t_ptrs = torch.tensor(B_T_addrs, device=DEVICE)
    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
    d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)
    d_g_t_lds = torch.tensor(g_T_lds, dtype=torch.int32, device=DEVICE)

    quantiles = [0.5, 0.2, 0.8]
    if provider == 'cublas':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_perf_fn(group_A, group_B), quantiles=quantiles)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: triton_perf_fn(d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size),
            quantiles=quantiles)
    if provider == 'triton-tma':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: triton_tma_perf_fn(d_a_ptrs, d_b_t_ptrs, d_c_ptrs, d_g_sizes, d_g_t_lds, group_size,
                                       dtype=torch.float16), quantiles=quantiles)
    return ms, max_ms, min_ms


benchmark_square_matrices.run(show_plots=True, print_data=True)
benchmark_batches.run(show_plots=True, print_data=True)

1. Import the required modules and utility functions

from typing import Optional
import torch
import triton
import triton.language as tl

This imports torch (the PyTorch framework), triton (for writing and tuning GPU kernels), and its language module tl.

2. Check for the CUDA backend and device capabilities

def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


def supports_tma():
    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9


def num_sms():
    if is_cuda():
        return torch.cuda.get_device_properties("cuda").multi_processor_count
    return 148

These helpers check whether the CUDA backend is active, whether the device supports TMA (Tensor Memory Accelerator, compute capability 9.0 or newer), and how many SMs (Streaming Multiprocessors) the device has.

3. Define the grouped matrix-multiplication kernel

@triton.autotune(
    configs=[
        # ... config list ...
    ],
    key=['group_size'],
)
@triton.jit
def grouped_matmul_kernel(
    # ... parameter list ...
):
    # kernel body

The triton.autotune and triton.jit decorators define an auto-tuned kernel that performs the grouped matrix multiplication.

Implementation logic:
  • Iterate over the problems in the group and load each GEMM's sizes (M, N, K) and leading dimensions.
  • Split each output matrix into tiles of size BLOCK_SIZE_M × BLOCK_SIZE_N.
  • Loop over the tiles assigned to this program, load the operands, and accumulate the matrix product (see the scheduling sketch after this list).
  • Interact with device memory through tl.load and tl.store.
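
To make the tile scheduling concrete, here is a host-side Python sketch of just that logic (no GPU required): a handful of "virtual SMs" stride over the tiles of several GEMM problems exactly the way tile_idx, last_problem_end and NUM_SM do in the kernel. The problem sizes and NUM_SM value below are made up for illustration.

def schedule(problem_sizes, NUM_SM=4, BLOCK_M=128, BLOCK_N=128):
    # return, for each virtual SM, the (problem, tile_m, tile_n) triples it would process
    assignments = {sm: [] for sm in range(NUM_SM)}
    for sm in range(NUM_SM):                    # plays the role of tl.program_id(0)
        tile_idx = sm
        last_problem_end = 0
        for g, (M, N, K) in enumerate(problem_sizes):
            num_m_tiles = -(-M // BLOCK_M)      # ceiling division, like tl.cdiv
            num_n_tiles = -(-N // BLOCK_N)
            num_tiles = num_m_tiles * num_n_tiles
            while last_problem_end <= tile_idx < last_problem_end + num_tiles:
                t = tile_idx - last_problem_end
                assignments[sm].append((g, t // num_n_tiles, t % num_n_tiles))
                tile_idx += NUM_SM              # jump to the next tile owned by this SM
            last_problem_end += num_tiles
    return assignments

print(schedule([(1024, 1024, 1024), (256, 256, 256)]))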

4. Define the host-side wrapper for the grouped GEMM

def group_gemm_fn(group_A, group_B):
    # ... implementation ...
    return group_C

This host-side function launches the kernel above to run the grouped GEMM.

Implementation logic:
  • Check that the two input lists contain the same number of matrices.
  • Walk over the matrix pairs and collect the per-problem metadata (sizes and leading dimensions).
  • Allocate the output matrices and gather all device pointers.
  • Move the metadata into device tensors and launch the kernel (a usage sketch follows this list).
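
A minimal usage sketch of group_gemm_fn, assuming the listing above has already been run (so DEVICE and group_gemm_fn are defined). The shapes are arbitrary; each problem in the group may have its own M/N/K.

As = [torch.rand((m, k), device=DEVICE, dtype=torch.float16) for m, k in [(256, 128), (512, 64)]]
Bs = [torch.rand((k, n), device=DEVICE, dtype=torch.float16) for k, n in [(128, 64), (64, 256)]]
Cs = group_gemm_fn(As, Bs)
for a, b, c in zip(As, Bs, Cs):
    torch.testing.assert_close(a @ b, c, atol=1e-2, rtol=1e-2)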

5. A TMA-based grouped GEMM variant (when TMA is supported)

@triton.autotune(
    tma_configs,
    key=['group_size'],
)
@triton.jit
def grouped_matmul_tma_kernel(
    # ... parameter list ...
):
    # TMA version of the kernel body

This is a higher-performance grouped GEMM kernel built on TMA; triton.set_allocator registers a dedicated memory allocator so the on-device TMA descriptors can be created.
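
For reference, here is that allocator pattern in isolation, mirroring the alloc_fn inside group_gemm_tma_fn above: on-device TMA descriptors need scratch global memory, and triton.set_allocator tells Triton how to obtain it.

def alloc_fn(size: int, alignment: int, stream):
    # Triton only needs raw bytes; an int8 CUDA tensor of `size` bytes is enough
    return torch.empty(size, device="cuda", dtype=torch.int8)

triton.set_allocator(alloc_fn)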

6. Tests and performance benchmarks

group_m = [1024, 512, 256, 128]
group_n = [1024, 512, 256, 128]
group_k = [1024, 512, 256, 128]
# ... prepare test data ...
tri_out = group_gemm_fn(group_A, group_B)
ref_out = [torch.matmul(a, b) for a, b in zip(group_A, group_B)]
# ... verify results ...
# ... benchmark code ...
  • Prepare the test data and call the functions above.
  • Verify that the results match torch.matmul using torch.allclose.
  • Finally, use triton.testing.perf_report and triton.testing.Benchmark to define the benchmark functions, run them, and generate a performance report (a minimal timing sketch follows this list).
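
The timing itself boils down to triton.testing.do_bench, which runs a callable repeatedly and returns the requested quantiles. A reduced sketch, reusing torch_perf_fn and the group_A/group_B test data from the listing above:

quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(
    lambda: torch_perf_fn(group_A, group_B),  # the workload being timed
    quantiles=quantiles,
)
print(f"median {ms:.3f} ms (p20 {min_ms:.3f} ms, p80 {max_ms:.3f} ms)")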

Persistent Matmul

I won't paste the code for this one. I was curious about it when I first saw it too, so here is mainly how it differs from an ordinary matmul kernel:

The two tutorial kernels implement a non-persistent matmul and a persistent matmul respectively, and they differ in the following ways:

  1. Task assignment
  • Non-persistent matmul (kernel 1): tiles are assigned once, and each program handles a single output tile.
  • Persistent matmul (kernel 2): a loop combined with the starting program ID (start_pid) gives finer-grained task assignment; one program can process several tile-level matmul tasks, which improves resource utilization.
  2. Resource utilization
  • Non-persistent matmul: because each program handles only one tile, compute resources can go underused, especially for large matrices.
  • Persistent matmul: by rescheduling work inside the kernel (via tl.range), GPU compute resources are used more effectively, especially in large-scale GEMMs.
  3. Data flow and store optimization
  • Non-persistent matmul: results are stored as soon as a tile is computed; simple and direct, but stores can become scattered in large problems.
  • Persistent matmul: stores are deferred (the result is written only after tile_id_c += NUM_SMS), which lets computation and stores overlap and raises overall efficiency.
  4. Hardware scheduling and performance
  • Non-persistent matmul: relies mainly on the GPU's automatic scheduling, which may adapt poorly to very large matrices.
  • Persistent matmul: program scheduling and work distribution are managed explicitly, reducing latency caused by the hardware scheduler; the gain is most visible on large GEMMs.
| Aspect | Non-persistent matmul | Persistent matmul |
| --- | --- | --- |
| Task assignment | Assigned once; each thread block handles a single task | Assigned in a loop; each thread block handles multiple tasks |
| Thread-block scheduling | Simple and direct; blocks work independently | More involved; blocks keep pulling work from the task queue |
| Resource utilization | Lower; blocks easily sit idle with many small tasks | Higher; makes full use of GPU resources |
| Best suited for | A single large matmul | Many small matmuls |
| Implementation complexity | Lower; simple logic | Higher; needs task-queue management and scheduling |

In short, persistent matmul is more efficient than non-persistent matmul on large-scale GEMMs: with finer-grained scheduling and resource management it keeps the GPU's compute resources busy and reduces the latency of store operations. A minimal sketch of the persistent scheduling pattern follows.
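
Since the persistent kernel itself is not reproduced here, the code below is only a sketch of the scheduling pattern, not the tutorial's full kernel: a fixed grid of NUM_SMS programs strides over all output tiles, each starting at its own program ID (start_pid in the tutorial). The sketch merely records which program owns each tile; a real kernel would compute and store the corresponding C tile at that point.

import torch
import triton
import triton.language as tl


@triton.jit
def persistent_tile_sketch(owner_ptr, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_SMS: tl.constexpr):
    start_pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_tiles = num_pid_m * num_pid_n
    # persistent loop: one program processes many tiles instead of exactly one
    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS):
        pid_m = tile_id // num_pid_n
        pid_n = tile_id % num_pid_n
        # stand-in for "compute output tile (pid_m, pid_n) and store it"
        tl.store(owner_ptr + pid_m * num_pid_n + pid_n, start_pid)


M, N, BLOCK_M, BLOCK_N = 512, 512, 128, 128
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
owner = torch.empty((M // BLOCK_M) * (N // BLOCK_N), device="cuda", dtype=torch.int32)
persistent_tile_sketch[(NUM_SMS, )](owner, M, N, BLOCK_M, BLOCK_N, NUM_SMS)
print(owner.reshape(M // BLOCK_M, N // BLOCK_N))  # which program produced each tile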

Block Scaled Matrix Multiplication

CUDA devices that support PTX 8.7 and later can use block-scaled matrix-multiplication instructions. To keep access to the scale factors low-latency inside the fast inner loop of the tensor-core matmul, the block scale factors must be stored contiguously in memory, in a layout that matches the access pattern.

The tensor-core instruction for block-scaled matmul computes the product:

$C = (A \times scale\_a) \; @ \; (B \times scale\_b)$

where $scale\_a$ and $scale\_b$ are the block scale factors of matrices A and B. In block-scaled matmul, each scale factor is broadcast along its K axis and multiplied with a vector of elements of A or B. The number of elements of A or B that each scale factor covers is called the vector size (VEC_SIZE).

In a row-major linear layout, the scale factors therefore have shapes:

$(M,\; K \,//\, VEC\_SIZE)$ and $(N,\; K \,//\, VEC\_SIZE)$
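
As a quick numeric check of these shapes (illustrative sizes only, using VEC_SIZE = 16 as in the nvfp4 case of the code below):

M, N, K, VEC_SIZE = 8192, 8192, 8192, 16
scale_a_shape = (M, K // VEC_SIZE)  # (8192, 512): one scale per 16 consecutive K elements in each row of A
scale_b_shape = (N, K // VEC_SIZE)  # (8192, 512)
print(scale_a_shape, scale_b_shape)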

However, to avoid non-contiguous memory accesses, it is more advantageous to store the scale factors in a packed block layout. For the left-hand-side (LHS) matrix, the layout is:

$\left( \dfrac{M}{32 \times 4},\; \dfrac{K}{VEC\_SIZE \times 4},\; 32,\; 4,\; 4 \right)$

This way, in the fast inner loop over K blocks, each tensor-core MMA can access a contiguous block of scale factors covering 128 rows along the M axis, corresponding to each BLOCK_M x BLOCK_K sub-tile of A.

To satisfy the semantics of Triton's dot_scaled, the scale factors are prepared in the 5D layout above and then logically transposed and reshaped into the 2D layout that the scaled tensor dot product expects.
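
The kernel performs this with scale_a.trans(0, 3, 2, 1, 4).reshape(...). Here is a small host-side torch sketch of the same permutation; the block sizes BLOCK_M = 128, BLOCK_K = 256, VEC_SIZE = 16 are assumptions chosen only for illustration.

import torch

BLOCK_M, BLOCK_K, VEC_SIZE = 128, 256, 16
packed = torch.arange(BLOCK_M * (BLOCK_K // VEC_SIZE), dtype=torch.float32)
packed = packed.reshape(BLOCK_M // 128, BLOCK_K // VEC_SIZE // 4, 32, 4, 4)
# same permutation as scale_a.trans(0, 3, 2, 1, 4) in the kernel below
unpacked = packed.permute(0, 3, 2, 1, 4).reshape(BLOCK_M, BLOCK_K // VEC_SIZE)
print(unpacked.shape)  # torch.Size([128, 16])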

import argparse

import torch
import triton
import triton.language as tl
import triton.profiler as proton
from triton.tools.tensor_descriptor import TensorDescriptor
from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor


def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


def supports_block_scaling():
    return is_cuda() and torch.cuda.get_device_capability()[0] == 10


def _matmul_launch_metadata(grid, kernel, args):
    ret = {}
    M, N, K = args["M"], args["N"], args["K"]
    kernel_name = kernel.name
    if "ELEM_PER_BYTE_A" and "ELEM_PER_BYTE_B" and "VEC_SIZE" in args:
        if args["ELEM_PER_BYTE_A"] == 1 and args["ELEM_PER_BYTE_B"] == 1:
            kernel_name += "_mxfp8"
        elif args["ELEM_PER_BYTE_A"] == 1 and args["ELEM_PER_BYTE_B"] == 2:
            kernel_name += "_mixed"
        elif args["ELEM_PER_BYTE_A"] == 2 and args["ELEM_PER_BYTE_B"] == 2:
            if args["VEC_SIZE"] == 16:
                kernel_name += "_nvfp4"
            elif args["VEC_SIZE"] == 32:
                kernel_name += "_mxfp4"
    ret["name"] = f"{kernel_name} [M={M}, N={N}, K={K}]"
    ret["flops"] = 2. * M * N * K
    return ret


@triton.jit(launch_metadata=_matmul_launch_metadata)
def block_scaled_matmul_kernel(  #
        a_desc, a_scale,  #
        b_desc, b_scale,  #
        c_desc,  #
        M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,  #
        stride_sk: tl.constexpr, stride_sb: tl.constexpr, stride_sc: tl.constexpr, stride_sd: tl.constexpr,
        output_type: tl.constexpr,  #
        ELEM_PER_BYTE_A: tl.constexpr,  #
        ELEM_PER_BYTE_B: tl.constexpr,  #
        VEC_SIZE: tl.constexpr,  #
        BLOCK_M: tl.constexpr,  #
        BLOCK_N: tl.constexpr,  #
        BLOCK_K: tl.constexpr,  #
        NUM_STAGES: tl.constexpr,  #
        USE_2D_SCALE_LOAD: tl.constexpr):  #

    if output_type == 0:
        output_dtype = tl.float32
    elif output_type == 1:
        output_dtype = tl.float16
    elif output_type == 2:
        output_dtype = tl.float8e4nv

    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m
    offs_am = pid_m * BLOCK_M
    offs_bn = pid_n * BLOCK_N
    offs_k_a = 0
    offs_k_b = 0

    ## block scale offsets
    offs_sm = (pid_m * (BLOCK_M // 128) + tl.arange(0, BLOCK_M // 128)) % M
    offs_sn = (pid_n * (BLOCK_N // 128) + tl.arange(0, BLOCK_N // 128)) % N

    MIXED_PREC: tl.constexpr = ELEM_PER_BYTE_A == 1 and ELEM_PER_BYTE_B == 2

    # For now it is recommended to use 2D scale loads for better performance.
    # In the future we will bring additional optimizations to either allow 5D loads,
    # the use of TMAs for scale factors, or both.
    if USE_2D_SCALE_LOAD:
        offs_inner = tl.arange(0, (BLOCK_K // VEC_SIZE // 4) * 32 * 4 * 4)
        a_scale_ptr = a_scale + offs_sm[:, None] * stride_sk + offs_inner[None, :]
        b_scale_ptr = b_scale + offs_sn[:, None] * stride_sk + offs_inner[None, :]
    else:
        offs_sk = tl.arange(0, (BLOCK_K // VEC_SIZE // 4))
        # MN spatial offsets for 32 element blocking
        offs_sc = tl.arange(0, 32)
        # offsets for both scale factor column ID (along K)
        # and spatial block column ID (along MN)
        offs_sd = tl.arange(0, 4)
        a_scale_ptr = a_scale + (offs_sm[:, None, None, None, None] * stride_sk +
                                 offs_sk[None, :, None, None, None] * stride_sb +
                                 offs_sc[None, None, :, None, None] * stride_sc +
                                 offs_sd[None, None, None, :, None] * stride_sd +
                                 offs_sd[None, None, None, None, :])
        b_scale_ptr = b_scale + (offs_sn[:, None, None, None, None] * stride_sk +
                                 offs_sk[None, :, None, None, None] * stride_sb +
                                 offs_sc[None, None, :, None, None] * stride_sc +
                                 offs_sd[None, None, None, :, None] * stride_sd +
                                 offs_sd[None, None, None, None, :])

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
        a = a_desc.load([offs_am, offs_k_a])
        b = b_desc.load([offs_bn, offs_k_b])
        scale_a = tl.load(a_scale_ptr)
        scale_b = tl.load(b_scale_ptr)
        if USE_2D_SCALE_LOAD:
            scale_a = scale_a.reshape(BLOCK_M // 128, BLOCK_K // VEC_SIZE // 4, 32, 4, 4)
            scale_b = scale_b.reshape(BLOCK_N // 128, BLOCK_K // VEC_SIZE // 4, 32, 4, 4)
        scale_a = scale_a.trans(0, 3, 2, 1, 4).reshape(BLOCK_M, BLOCK_K // VEC_SIZE)
        scale_b = scale_b.trans(0, 3, 2, 1, 4).reshape(BLOCK_N, BLOCK_K // VEC_SIZE)
        if MIXED_PREC:
            accumulator = tl.dot_scaled(a, scale_a, "e4m3", b.T, scale_b, "e2m1", accumulator)
        elif ELEM_PER_BYTE_A == 2 and ELEM_PER_BYTE_B == 2:
            accumulator = tl.dot_scaled(a, scale_a, "e2m1", b.T, scale_b, "e2m1", accumulator)
        else:
            accumulator = tl.dot_scaled(a, scale_a, "e4m3", b.T, scale_b, "e4m3", accumulator)
        offs_k_a += BLOCK_K // ELEM_PER_BYTE_A
        offs_k_b += BLOCK_K // ELEM_PER_BYTE_B
        a_scale_ptr += (BLOCK_K // VEC_SIZE // 4) * stride_sb
        b_scale_ptr += (BLOCK_K // VEC_SIZE // 4) * stride_sb

    c_desc.store([offs_am, offs_bn], accumulator.to(output_dtype))


def block_scaled_matmul(a_desc, a_scale, b_desc, b_scale, dtype_dst, M, N, K, configs):
    output = torch.empty((M, N), dtype=dtype_dst, device="cuda")
    if dtype_dst == torch.float32:
        dtype_dst = 0
    elif dtype_dst == torch.float16:
        dtype_dst = 1
    elif dtype_dst == torch.float8_e4m3fn:
        dtype_dst = 2
    else:
        raise ValueError(f"Unsupported dtype: {dtype_dst}")

    BLOCK_M = configs["BLOCK_SIZE_M"]
    BLOCK_N = configs["BLOCK_SIZE_N"]
    c_desc = TensorDescriptor.from_tensor(output, [BLOCK_M, BLOCK_N])

    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
    block_scaled_matmul_kernel[grid](a_desc, a_scale, b_desc, b_scale, c_desc, M, N, K, a_scale.stride(0),
                                     a_scale.stride(1), a_scale.stride(2), a_scale.stride(3), dtype_dst,
                                     configs["ELEM_PER_BYTE_A"], configs["ELEM_PER_BYTE_B"], configs["VEC_SIZE"],
                                     configs["BLOCK_SIZE_M"], configs["BLOCK_SIZE_N"], configs["BLOCK_SIZE_K"],
                                     configs["num_stages"], USE_2D_SCALE_LOAD=True)
    return output


def initialize_block_scaled(M, N, K, block_scale_type="nvfp4", compute_reference=False):
    BLOCK_M = 128
    BLOCK_N = 256
    BLOCK_K = 256 if "fp4" in block_scale_type else 128
    VEC_SIZE = 16 if block_scale_type == "nvfp4" else 32
    assert block_scale_type in ["nvfp4", "mxfp4", "mxfp8", "mixed"], f"Invalid block scale type: {block_scale_type}"
    ELEM_PER_BYTE_A = 2 if "fp4" in block_scale_type else 1
    ELEM_PER_BYTE_B = 1 if block_scale_type == "mxfp8" else 2

    device = "cuda"
    a_ref = MXFP4Tensor(size=(M, K), device=device).random()
    # Similar to Hopper's wgmma symmetric fp8 instruction, the RHS is expected
    # to be in col-major layout for Blackwell's tcgen05.mma when using fp4 operands.
    # To conform to the expected semantics of tl.dot_scaled, (M, K) x (K, N),
    # the data is generated in col-major layout, packed along K for fp4, and then
    # logically transposed. Note that if one operand is of fp8 precision, unlike Hopper,
    # Blackwell supports both row-major and col-major layouts for the RHS matrix.
    # For the mixed-precision case, the fp4 RHS can be either in row or col-major layout.
    # But for performance reason, it is recommended to use col-major layout. If TMA is used
    # for the fp4 RHS operand load in mixed-precision dot, as in this tutorial, it must be
    # in col-major layout.
    b_ref = MXFP4Tensor(size=(N, K), device=device).random()

    if block_scale_type in ["mxfp8", "mixed"]:
        a_ref = a_ref.to(torch.float32)
        a = a_ref.to(torch.float8_e4m3fn)
    else:
        # Pack two fp4 elements per byte along K
        a = a_ref.to_packed_tensor(dim=1)

    if block_scale_type == "mxfp8":
        b_ref = b_ref.to(torch.float32)
        b = b_ref.to(torch.float8_e4m3fn)
    else:
        b = b_ref.to_packed_tensor(dim=1)

    b_ref = b_ref.to(torch.float32).T

    a_desc = TensorDescriptor.from_tensor(a, [BLOCK_M, BLOCK_K // ELEM_PER_BYTE_A])
    if block_scale_type == "mixed":
        b_desc = TensorDescriptor(
            b,
            shape=[N, K // ELEM_PER_BYTE_B],
            strides=[K // ELEM_PER_BYTE_B, 1],
            block_shape=[BLOCK_N, BLOCK_K // ELEM_PER_BYTE_B],
        )
    else:
        b_desc = TensorDescriptor.from_tensor(b, [BLOCK_N, BLOCK_K // ELEM_PER_BYTE_B])

    epsilon = 1e-8
    a_scale = torch.rand((M // 128, K // VEC_SIZE // 4, 32, 4, 4), device=device) + epsilon
    b_scale = torch.rand((N // 128, K // VEC_SIZE // 4, 32, 4, 4), device=device) + epsilon
    if block_scale_type == "nvfp4":
        a_scale = a_scale.to(torch.float8_e4m3fn)
        b_scale = b_scale.to(torch.float8_e4m3fn)
        a_scale_ref = a_scale
        b_scale_ref = b_scale
    elif block_scale_type in ["mxfp4", "mxfp8", "mixed"]:
        a_scale_ref = MXScaleTensor(a_scale)
        b_scale_ref = MXScaleTensor(b_scale)
        a_scale = a_scale_ref.data
        b_scale = b_scale_ref.data

    reference = None
    if compute_reference:
        a_scale_ref = a_scale_ref.to(torch.float32)
        b_scale_ref = b_scale_ref.to(torch.float32)

        def unpack_scale(packed):
            num_chunk_m, num_chunk_k, _, _, _ = packed.shape
            return packed.permute(0, 3, 2, 1, 4).reshape(num_chunk_m * 128, num_chunk_k * 4).contiguous()

        a_scale_ref = unpack_scale(a_scale_ref).repeat_interleave(VEC_SIZE, dim=1)[:M, :K]
        b_scale_ref = unpack_scale(b_scale_ref).repeat_interleave(VEC_SIZE, dim=1).T.contiguous()[:K, :N]
        reference = torch.matmul(a_ref.to(torch.float32) * a_scale_ref, b_ref * b_scale_ref)

    configs = {
        "BLOCK_SIZE_M": BLOCK_M,
        "BLOCK_SIZE_N": BLOCK_N,
        "BLOCK_SIZE_K": BLOCK_K,
        "num_stages": 4,
        "ELEM_PER_BYTE_A": ELEM_PER_BYTE_A,
        "ELEM_PER_BYTE_B": ELEM_PER_BYTE_B,
        "VEC_SIZE": VEC_SIZE,
    }
    return a_desc, a_scale, b_desc, b_scale, configs, reference


def validate_block_scaled(M, N, K, block_scale_type="nvfp4"):

    def alloc_fn(size: int, align: int, _):
        return torch.empty(size, dtype=torch.int8, device="cuda")

    if block_scale_type == "mixed":
        # This is needed for TMA with the descriptor created on the device.
        # TMA load for mixed-precision fp4 is supported only by device TMA.
        triton.set_allocator(alloc_fn)

    a_desc, a_scale, b_desc, b_scale, configs, reference = initialize_block_scaled(M, N, K, block_scale_type,
                                                                                   compute_reference=True)
    output = block_scaled_matmul(a_desc, a_scale, b_desc, b_scale, torch.float16, M, N, K, configs)
    torch.testing.assert_close(reference, output.to(torch.float32), atol=1e-3, rtol=1e-3)
    print(f"✅ (pass {block_scale_type})")


def bench_block_scaled(K, block_scale_type="nvfp4", reps=10):
    assert K % 128 == 0
    M = 8192
    N = 8192
    print(f"Problem Shape = {M}x{N}x{K}")

    a_desc, a_scale, b_desc, b_scale, configs, _ = initialize_block_scaled(M, N, K, block_scale_type,
                                                                           compute_reference=False)
    _ = block_scaled_matmul(a_desc, a_scale, b_desc, b_scale, torch.float16, M, N, K, configs)

    proton.activate(0)
    for _ in range(reps):
        _ = block_scaled_matmul(a_desc, a_scale, b_desc, b_scale, torch.float16, M, N, K, configs)
    proton.deactivate(0)
    print("Done benchmarking")


def show_profile(profile_name):
    import triton.profiler.viewer as proton_viewer
    metric_names = ["time/ms"]
    metric_names = ["tflop/s"] + metric_names
    file_name = f"{profile_name}.hatchet"
    tree, metrics = proton_viewer.parse(metric_names, file_name)
    proton_viewer.print_tree(tree, metrics)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-K", type=int, required=False, default=512)
    parser.add_argument("--K_range", type=int, nargs=2)
    parser.add_argument("--K_step", type=int, default=512)
    parser.add_argument("--bench", action="store_true", default=True)
    parser.add_argument("--format", type=str, choices=["mxfp4", "nvfp4", "mxfp8", "mixed"], default="nvfp4")
    args = parser.parse_args()

    if not supports_block_scaling():
        print("⛔ This example requires GPU support for block scaled matmul")
    else:
        if args.K and args.K_range is None:
            args.K_range = [args.K, args.K]
            args.K_step = 1  # doesn't matter as long as it's not 0
        torch.manual_seed(42)

        validate_block_scaled(8192, 8192, 8192, block_scale_type=args.format)
        if args.bench:
            proton.start("block_scaled_matmul", hook="triton")
            proton.deactivate(0)  # Skip argument creation
            for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
                bench_block_scaled(K, reps=10000, block_scale_type=args.format)
            proton.finalize()
            show_profile("block_scaled_matmul")

Summary

  1. Through these exercises I learned Triton's basic syntax and some GPU knowledge; the parts I still don't fully understand I will revisit after more study.

  2. I will keep following this direction and build up a complete knowledge tree.

Reference

  1. 从啥也不会到CUDA GEMM优化
  2. Tutorials — Triton documentation
