当前位置: 首页 > backend >正文

精通 triton 使用 MLIR 的源码逻辑 - 第002节:再掌握一些 triton 语法 — 通过 02 softmax

1. 热身预备向量的  softmax 函数


Softmax 函数是深度学习和机器学习中广泛使用的激活函数,主要用于多分类问题,将输入向量转换为概率分布,使得所有输出值的和为 1。

1.1. Softmax 函数原理

设  X=[x_1,x_2, \dots , x_n]  则其 Softmax 函数的数学定义为,

                \text{Softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^n e^{x_j}}

Softmax 函数的性质

归一化:输出值在 [0, 1] 之间,且总和为 1,适合概率解释;

单调性:较大的输入值对应较大的输出概率;

可导性:便于反向传播优化(梯度计算);

1.2. Softmax 计算示例

       假设输入向量为 x = [2.0, 1.0, 0.1]

  计算步骤:

        step1  计算指数

                e^{2.0} \approx 7.389, \quad e^{1.0} \approx 2.718, \quad e^{0.1} \approx 1.105

        step2  求和

                \sum = 7.389 + 2.718 + 1.105 \approx 11.212

        step3  归一化:

                \text{Softmax}(x_1) = \frac{7.389}{11.212} \approx 0.659 \\ \\ \text{Softmax}(x_2) = \frac{2.718}{11.212} \approx 0.242 \\ \\ \text{Softmax}(x_3) = \frac{1.105}{11.212} \approx 0.099

  最终输出:

                \text{Softmax}(x) \approx [0.659, 0.242, 0.099]

1.3. 矩阵逐行 Softmax 计算


在深度学习中,Softmax 通常用于矩阵(如神经网络的输出层)。通常每行代表一个样本的不同类别得分。假设输入矩阵:

X = \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ \end{bmatrix}

       可以看出矩阵 X 的第一行与第二行成比例关系,所以,可以期待其对应元素的概率值也应该相等。

计算过程如下,

    逐行计算 Softmax:

        第一行 [1, 2, 3]

               \text{Softmax}([1, 2, 3]) = \left[ \frac{e^1}{e^1+e^2+e^3}, \frac{e^2}{e^1+e^2+e^3}, \frac{e^3}{e^1+e^2+e^3} \right] \approx [0.090, 0.245, 0.665] 

        第二行 [4, 5, 6]

               \text{Softmax}([4, 5, 6]) \approx [0.090, 0.245, 0.665]

    最终输出矩阵:

               \text{Softmax}(X) \approx \begin{bmatrix} 0.090 & 0.245 & 0.665 \\ 0.090 & 0.245 & 0.665 \\ \end{bmatrix}

1.4. softmax 的数值稳定性优化


由于指数计算可能导致数值溢出(exp(x) 在 x 较大时爆炸),通常采用 Log-Softmax 或 减去最大值的技巧,计算结果不变:

\text{Softmax}(x_i) = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}

计算示例:

            x = [1000, 1001, 1002]

直接计算 exp(1000) 会溢出,但减去 max(x)=1002 后:

            x' = [-2, -1, 0]

再计算 Softmax:

            \text{Softmax}(x') = [ \frac{e^{-2}}{e^{-2}+e^{-1}+e^0}, \frac{e^{-1}}{e^{-2}+e^{-1}+e^0}, \frac{e^0}{...} ] \approx [0.090, 0.245, 0.665]

1.5. 通过 python 来验证上述理论

用 Python 中的 softmax验证上述计算:

cpu 版本的 stable softmax

hello_softmax.py  :


import numpy as npdef stable_softmax(x):x = x - np.max(x, axis=-1, keepdims=True)exp_x = np.exp(x)return exp_x / np.sum(exp_x, axis=-1, keepdims=True)def softmax(x):exp_x = np.exp(x)return exp_x / np.sum(exp_x, axis=-1, keepdims=True)X = np.array([[1, 2, 3], [4, 5, 6]])
print('\n softmzx X s:')
print(softmax(X))y = np.array([500, 501, 502])
print('\n softmax y s:')
print(stable_softmax(y))print('\n softmax y  :')
print(softmax(y))z = np.array([1000, 1002, 1002])
print('\n softmax z s:')
print(stable_softmax(z))print('\nsoftmax z  :')
print(softmax(z))

实验中可以发现,在普通的 python softmax 中,处理 [1000, 1001, 1002] 时遇到了溢出,无法顺利计算数学意义上的概率分布。

gpu 初级版本 stable softmax

naive_softmax.py : 

import torchDEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")def naive_softmax(x):x_max = x.max(dim=1)[0]z = x - x_max[:, None]numerator = torch.exp(z)denominator = numerator.sum(dim=1)ret = numerator / denominator[:, None]return rettorch.manual_seed(0)
x = torch.randn(8, 8, device=DEVICE)
y_naive  = naive_softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_naive, y_torch), (y_naive, y_torch)print('y_naive =')
print(y_naive)

做点语法解释,

(1.) torch.max() 函数的基本用法
torch.max(input, dim) 函数有两个主要功能:

返回指定维度上的最大值,同时返回最大值对应的索引
当指定 dim 参数时,它会返回一个包含两个张量的元组
第一个张量是最大值(values)
第二个张量是最大值的索引(indices)

(2.) 语法解析:x.max(dim=1)[0]

x_max = x.max(dim=1)[0]

dim=1 表示沿着第1维度(列方向)计算最大值
[0] 表示取返回元组的第一个元素(最大值张量)

运行: 

1.6. 稳定版 softmax 的简单的理论证明

Softmax 函数的定义

    \text{Softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^n e^{x_j}}

如果对输入向量 xx 的每个元素减去同一个常数 cc,Softmax 结果不变:

    \text{Softmax}(x_i - c) = \frac{e^{x_i - c}}{\sum_{j=1}^n e^{x_j - c}} = \frac{e^{x_i} \cdot e^{-c}}{e^{-c} \cdot \sum_{j=1}^n e^{x_j}} = \frac{e^{x_i}}{\sum_{j=1}^n e^{x_j}} = \text{Softmax}(x_i)
结论分析:

        (1.) 减去任意常数 c不影响 Softmax 的输出;

        (2.) 通常选择 c = \max(x),这样可以避免数值溢出(因为最大的指数项变为 e^0 = 1.0,其他元素不大于 1.0 )。

2. Triton 实现 stable softmax

       是将 triton tutorial 02-fused-softmax.py 简化到 70行左右:

triton_stable_softmax.py :


import torchimport triton
import triton.language as tl
from triton.runtime import driverDEVICE = triton.runtime.driver.active.get_active_torch_device()@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,num_stages: tl.constexpr):row_start = tl.program_id(0)row_step = tl.num_programs(0)for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):row_start_ptr = input_ptr + row_idx * input_row_stridecol_offsets = tl.arange(0, BLOCK_SIZE)input_ptrs = row_start_ptr + col_offsetsmask = col_offsets < n_colsrow = tl.load(input_ptrs, mask=mask, other=-float('inf'))row_minus_max = row - tl.max(row, axis=0)numerator = tl.exp(row_minus_max)denominator = tl.sum(numerator, axis=0)softmax_output = numerator / denominatoroutput_row_start_ptr = output_ptr + row_idx * output_row_strideoutput_ptrs = output_row_start_ptr + col_offsetstl.store(output_ptrs, softmax_output, mask=mask)properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}def softmax(x):n_rows, n_cols = x.shapeBLOCK_SIZE = triton.next_power_of_2(n_cols)num_warps = 8num_stages = 4 if SIZE_SMEM > 200000 else 2y = torch.empty_like(x)kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,num_stages=num_stages, num_warps=num_warps, grid=(1, ))kernel._init_handles()n_regs = kernel.n_regssize_smem = kernel.metadata.sharedoccupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)occupancy = min(occupancy, SIZE_SMEM // size_smem)num_programs = NUM_SM * occupancynum_programs = min(num_programs, n_rows)kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)return ytorch.manual_seed(0)
x = torch.randn(64, 64, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)print('y_triton =')
print(y_triton[1:16, 1:16])

先看运行结果:

triton kernel 的逐行注释:

@triton.jit
def softmax_kernel(output_ptr,#输出矩阵指针input_ptr,#输入函数指针input_row_stride,#行主序的输入矩阵的 strideoutput_row_stride,#行主序的输出矩阵的 striden_rows,#矩阵的行数n_cols,#矩阵的列数BLOCK_SIZE: tl.constexpr,#每一行中含 n_cols 个有效元素,block_size 为能容纳下这么多元素的一块空间长度 len,同时 len 为 2 的整数次幂。num_stages: tl.constexpr):#一个指导流水线阶段数量意向值,后边展开说row_start = tl.program_id(0)#每个 triton 程序(类比 cuda 的 block) 每次迭代的过程中只负责矩阵一行数据的 softmax 计算。row_step = tl.num_programs(0)#总共能启动多少个 triton 程序,类比 cuda 的 block,也就是下一次迭代需要跨过的行数。#接下来这行,tl.range() 这个类似 C++ 的迭代器#其中特别需要指出的是 num_stages 这是一个意向值,在问题规模太大,一个就占据太多资源时,实际gpu 代码中的 stage 可能只有一个。for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):row_start_ptr = input_ptr + row_idx * input_row_stride#根据自己的额 row_idx 来找到输入矩阵的一行数据的起始地址col_offsets = tl.arange(0, BLOCK_SIZE)#程序中的 threadIdx.x 编号构成的一维数组,即 tensorinput_ptrs = row_start_ptr + col_offsets#本 thread 在矩阵中本行的实际取数地址mask = col_offsets < n_cols#边界检查用的 maskrow = tl.load(input_ptrs, mask=mask, other=-float('inf'))#实际加载数据,此mask 彼mask;row_minus_max = row - tl.max(row, axis=0)#这里的 tl.max() 将会引发 reduce 操作;然后本行每个元素都会减掉本行的最大元素numerator = tl.exp(row_minus_max)#计算 e^{x_j},即 softmax 中的 分子部分;denominator = tl.sum(numerator, axis=0)#把全部分子累加,作为 softmax 的分母softmax_output = numerator / denominator#计算新的元素值,即分子除以分母;output_row_start_ptr = output_ptr + row_idx * output_row_stride#计算回存数据行首在显存中的地址output_ptrs = output_row_start_ptr + col_offsets#计算本 thread 所需要存储的数据在显存中的具体地址,因线程不同而不同。tl.store(output_ptrs, softmax_output, mask=mask)#使用相同的掩码回存处理后的结果

2.1.  triton.next_power_of_2(n_cols) 的作用

返回 2^k,其中  k 为使得不等式    2^k \ge n\_col   成立的最小的正整数。

2.2.  _init_handles() 的作用

(1.)函数 _init_handles() 的作用

        在 Triton 的 JIT 编译框架中,_init_handles() 是一个内部方法,主要用于 初始化内核的底层执行句柄。_init_handles() 的主要职责是编译内核,在首次调用时,将 Triton 的 Python 代码编译为目标设备(如 GPU)的高效机器码(如 CUDA PTX)。生成优化后的计算图(DAG)和内存访问模式。分配运行时资源,为内核分配显存、流式多处理器(SM)资源等。绑定输入/输出张量的设备指针。并且做缓存管理,缓存编译后的内核二进制,避免重复编译(类似 PyTorch 的 torch.jit 缓存机制)。

(2.)设计成显式调用

        在 Triton 中,内核通常通过 kernel[grid](*args) 触发隐式编译和执行。但以下场景需手动调用 _init_handles():
预热(Warm-up):提前编译内核,避免首次运行时因编译延迟影响性能。
参数调优:在 warmup() 后调整执行配置(如 num_warps、num_stages),需重新初始化。
低延迟场景:确保内核在关键路径前已就绪。

2.3. 执行 triton kernel 

kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)

其他信息跟 cuda kernel 有明显的对应关系,这里借着这行代码,仅对 num_stages 多说一些。

首先,num_stages 是一个 意向值(hint),指导流水线阶段数量,实际生成的 GPU 代码中 Triton 编译器会根据硬件资源限制和问题规模进行优化调整,最终可能不会完全按照设定的 num_stages 生成机器码。

2.3.1.  num_stages 的原则和原理

num_stages 是编译时常量,但仅作为提示(hint):

        Triton 编译器会尝试按照 num_stages 进行流水线调度,但如果寄存器压力过大(每个 stage 需要额外的寄存器存储中间结果)或者共享内存/计算资源不足(例如 SM 上的线程块资源受限),以及当问题规模太小(如果 n_rows 很小,增加 num_stages 可能不会带来性能提升)编译器可能会自动降低 num_stages,甚至退化为 1(即无流水线)。

          更具体来说,每个 stage 在流水线中需要独立的寄存器组来存储中间状态。如果 BLOCK_SIZE 很大(例如处理长向量),每个线程需要更多的寄存器,可能导致编译器被迫减少 num_stages 以避免 register spilling(寄存器溢出到全局内存,严重降低性能)。例如,在 BLOCK_SIZE=1024 且 num_stages=4 时,编译器可能会发现寄存器不够用,从而最终生成 num_stages=1 的代码。

2.3.2.  验证实际的 num_stages

        Triton 提供了性能分析工具(如 triton.testing.do_bench),这样可以通过测量不同 num_stages 的性能来间接推断实际使用的阶段数。如果增加 num_stages 但性能没有提升,可能实际阶段数已被编译器优化降低。

2.3.3. 最佳实践

保守设置:通常 num_stages=3 或 4 是一个合理的起点(根据 NVIDIA GPU 的 SM 架构特性)。

资源敏感调整:如果 BLOCK_SIZE 较大,可能需要减少 num_stages。

动态适配:可以通过 triton.autotune 自动选择最优配置(Triton 内置支持自动调优)。

2.3.4.  示例场景分析

        假设有以下内核:

@triton.jit
def kernel(..., BLOCK_SIZE: tl.constexpr, num_stages: tl.constexpr):for i in tl.range(0, n, num_stages=num_stages):...


如果设置 num_stages=4 但 BLOCK_SIZE=2048(每个线程需要大量寄存器),这时 triton 编译器可能实际生成 num_stages=1 的代码,因为寄存器不足。

       如果设置 num_stages=4 且 BLOCK_SIZE=128,这时编译器可能会成功生成 4 阶段流水线代码,充分利用指令级并行(ILP)。

       总而言之,num_stages 是一个建议性的目标值,实际执行时 Triton 编译器会根据资源约束自动优化。理解这一点对性能调优非常重要,尤其是在处理不同规模的问题时。

http://www.xdnf.cn/news/15656.html

相关文章:

  • 生成式引擎优化(GEO)核心解析:下一代搜索技术的演进与落地策略
  • Python包发布与分发全指南:从PyPI到企业私有仓库
  • LiteCloud超轻量级网盘项目基于Spring Boot
  • Solr7升级Solr8全攻略:从Core重命名到IK分词兼容,零业务中断实战指南
  • css样式中的选择器和盒子模型
  • 《汇编语言:基于X86处理器》第8章 高级过程(2)
  • QT跨平台应用程序开发框架(10)—— Qt窗口
  • PyCharm 高效入门指南(引言 + 核心模块详解)
  • C++拷贝构造
  • 【数据结构】栈和队列
  • 李宏毅《生成式人工智能导论》 | 第15讲-第18讲:生成的策略-影像有关的生成式AI
  • 【读论文】AgentOrchestra 解读:LLM 智能体学会「团队协作」去解决复杂任务
  • 河南萌新联赛2025第一场-河南工业大学
  • Python--plist文件的读取
  • 【Linux】LVS(Linux virual server)
  • python-字典、集合、序列切片、字符串操作(笔记)
  • 大型语言模型的白日梦循环
  • Git简介与特点:从Linux到分布式版本控制的革命
  • Python 网络爬虫 —— 代理服务器
  • github不能访问怎么办
  • echart设置trigger: ‘axis‘不显示hover效果
  • C 语言基础第 08 天:数组与冒泡排序
  • HTTPS的工作原理及DNS的工作过程
  • 相位中心偏置天线的SAR动目标检测
  • 基于Echarts的气象数据可视化网站系统的设计与实现(Python版)
  • 【LeetCode 热题 100】108. 将有序数组转换为二叉搜索树
  • Git 多人协作实战:从基础操作到分支管理全流程记录
  • 深入了解linux系统—— 信号的捕捉
  • 如何将 ONLYOFFICE 文档集成到使用 Laravel 框架编写的 PHP 网络应用程序中
  • Nginx/OpenResty HTTP 请求处理阶段与 Lua 实践全解20250717