FlashAttention:突破Transformer内存瓶颈的IO感知革命
FlashAttention:突破Transformer内存瓶颈的IO感知革命
当Transformer模型处理4096个token的序列时,标准注意力机制需要消耗67GB内存存储中间矩阵——这个数字足以让最先进的GPU崩溃。FlashAttention通过重新思考内存访问模式,将这一数字降低到仅需原始内存的1/10,同时保持数学等价性。
引言:注意力机制的内存困境
2017年,Vaswani等人提出的Transformer架构彻底改变了自然语言处理领域。其核心自注意力机制(Self-Attention Mechanism)使模型能够同时处理序列中的所有位置,并动态计算它们之间的相关性。标准注意力计算遵循以下公式:
S=QK⊤∈RN×N,P=softmax(S/dk),O=PV\mathbf{S} = \mathbf{Q}\mathbf{K}^\top \in \mathbb{R}^{N \times N}, \quad \mathbf{P} = \mathrm{softmax}(\mathbf{S}/\sqrt{d_k}), \quad \mathbf{O} = \mathbf{P}\mathbf{V}S=QK⊤∈RN×N,P=softmax(S/dk),O=PV
其中Q\mathbf{Q}Q、K\mathbf{K}K、V\mathbf{V}V分别表示查询、键和值矩阵,NNN为序列长度,dkd_kdk为键向量的维度。缩放因子1/dk1/\sqrt{d_k}1/dk用于防止点积值过大导致softmax梯度消失。
然而,这种实现存在根本性缺陷:需要显式存储中间矩阵S\mathbf{S}S和P\mathbf{P}P到高带宽内存(HBM),导致O(N2)O(N^2)O(N2)的内存需求和O(Nd+N2)O(Nd + N^2)O(Nd+N2)次HBM访问。当序列长度NNN增大时(如达到1k-8k),内存访问成为主要性能瓶颈,限制了模型处理长序列的能力。
FlashAttention核心机制:IO感知优化
分块计算策略(Tiling)
FlashAttention的核心创新在于利用GPU内存层次结构进行IO感知优化。现代GPU具有多层内存结构:
- HBM:容量大(40-80GB),但带宽相对较低(1.5-2.0TB/s)
- SRAM:容量小(每流多处理器192KB),但带宽极高(19TB/s)
FlashAttention通过分块策略将计算分解为适合SRAM的小块:
Q=[Q1⋮QTr],K=[K1⋮KTc],V=[V1⋮VTc]\mathbf{Q} = \begin{bmatrix}\mathbf{Q}_1\\\vdots\\\mathbf{Q}_{T_r}\end{bmatrix}, \quad \mathbf{K} = \begin{bmatrix}\mathbf{K}_1\\\vdots\\\mathbf{K}_{T_c}\end{bmatrix}, \quad \mathbf{V} = \begin{bmatrix}\mathbf{V}_1\\\vdots\\\mathbf{V}_{T_c}\end{bmatrix}Q= Q1⋮QTr ,K= K1⋮