FlashAttention Formula Derivation
This post covers only the formula derivation behind FlashAttention; for background, see:
- paper
- blog
1. How to compute softmax blockwise
For a vector $x \in \mathbb{R}^B$, define:

$$
x=[x_1,x_2,\dots,x_B] \\
m(x) := \max_i x_i = \max(x_1,x_2,\dots,x_B) \\
f(x) := [e^{x_1-m(x)}, e^{x_2-m(x)}, \dots, e^{x_B-m(x)}] \\
l(x) := \sum_{i}^{B}f(x)_i=e^{x_1-m(x)} + e^{x_2-m(x)} + \dots + e^{x_B-m(x)} \\
\mathrm{softmax}(x) := \frac{f(x)}{l(x)}=\left[\frac{e^{x_1-m(x)}}{l(x)}, \frac{e^{x_2-m(x)}}{l(x)}, \dots, \frac{e^{x_B-m(x)}}{l(x)}\right]
$$

Subtracting $m(x)$ before exponentiating does not change the result (the factor $e^{-m(x)}$ cancels between numerator and denominator), but it keeps every exponent $\le 0$ and thus avoids overflow.
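As a sanity check, the definitions above can be sketched directly in NumPy. The function names mirror the symbols $m$, $f$, $l$; this is an illustrative sketch, not the paper's implementation:

```python
import numpy as np

def m(x):
    # m(x) := max_i x_i
    return np.max(x)

def f(x):
    # f(x)_i := e^{x_i - m(x)}, shifted so every exponent is <= 0
    return np.exp(x - m(x))

def l(x):
    # l(x) := sum_i f(x)_i, the softmax denominator
    return np.sum(f(x))

def softmax(x):
    # numerically safe softmax: f(x) / l(x)
    return f(x) / l(x)

x = np.array([1.0, 2.0, 3.0])
print(np.allclose(softmax(x), np.exp(x) / np.exp(x).sum()))  # True
```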
Then for vectors $x^{(1)}, x^{(2)} \in \mathbb{R}^B$ and their concatenation $x=[x^{(1)},x^{(2)}] \in \mathbb{R}^{2B}$:

$$
x^{(1)} = [x_1^{(1)},x_2^{(1)},\dots,x_B^{(1)}], \quad x^{(2)}=[x_1^{(2)},x_2^{(2)},\dots,x_B^{(2)}] \\
m(x)=m([x^{(1)}, x^{(2)}])=\max(m(x^{(1)}), m(x^{(2)})) \\
f(x)=[e^{m(x^{(1)}) - m(x)}f(x^{(1)}),\ e^{m(x^{(2)}) - m(x)}f(x^{(2)})] \\
l(x)=l([x^{(1)}, x^{(2)}])=e^{m(x^{(1)}) - m(x)}l(x^{(1)})+ e^{m(x^{(2)}) - m(x)}l(x^{(2)}) \\
\mathrm{softmax}(x)=\frac{f(x)}{l(x)}
$$
Let us derive $f(x)$:

$$
f(x)=[e^{x_1^{(1)} - m(x)},\dots,e^{x_B^{(1)} - m(x)},\ e^{x_1^{(2)} - m(x)},\dots,e^{x_B^{(2)} - m(x)}] \\
\to f(x)=[e^{x_1^{(1)} - m(x^{(1)}) + m(x^{(1)}) - m(x)},\ e^{x_2^{(1)} - m(x^{(1)}) + m(x^{(1)}) - m(x)},\ \dots] \\
\to f(x)=[e^{x_1^{(1)} - m(x^{(1)})} \cdot e^{m(x^{(1)}) - m(x)},\ e^{x_2^{(1)} - m(x^{(1)})} \cdot e^{m(x^{(1)}) - m(x)},\ \dots] \\
\to f(x)=[e^{m(x^{(1)}) - m(x)}f(x^{(1)}),\ e^{m(x^{(2)}) - m(x)}f(x^{(2)})]
$$
Next, $l(x)$:

$$
l(x)=\sum_{i}^{B}e^{m(x^{(1)}) - m(x)}f(x^{(1)})_i + \sum_{i}^{B}e^{m(x^{(2)}) - m(x)}f(x^{(2)})_i \\
\to l(x)=e^{m(x^{(1)}) - m(x)}l(x^{(1)})+ e^{m(x^{(2)}) - m(x)}l(x^{(2)})
$$
Therefore:

$$
\mathrm{softmax}(x)=\frac{f(x)}{l(x)}=\left[\frac{e^{m(x^{(1)}) - m(x)}f(x^{(1)})}{e^{m(x^{(1)}) - m(x)}l(x^{(1)})+ e^{m(x^{(2)}) - m(x)}l(x^{(2)})},\ \frac{e^{m(x^{(2)}) - m(x)}f(x^{(2)})}{e^{m(x^{(1)}) - m(x)}l(x^{(1)})+ e^{m(x^{(2)}) - m(x)}l(x^{(2)})}\right]
$$
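The merge rule can be verified numerically: combining the per-block statistics $(m, f, l)$ with the formulas above reproduces the softmax of the concatenated vector exactly. A minimal sketch:

```python
import numpy as np

def stats(x):
    # per-block statistics: (max, shifted exponentials, normalizer)
    mx = np.max(x)
    fx = np.exp(x - mx)
    return mx, fx, np.sum(fx)

x1 = np.array([0.5, 2.0, -1.0])
x2 = np.array([3.0, 0.1, 1.5])

m1, f1, l1 = stats(x1)
m2, f2, l2 = stats(x2)

# merged statistics, per the blockwise formulas above
mx = max(m1, m2)
fx = np.concatenate([np.exp(m1 - mx) * f1, np.exp(m2 - mx) * f2])
lx = np.exp(m1 - mx) * l1 + np.exp(m2 - mx) * l2

# compare against softmax computed on the full vector in one shot
full = np.concatenate([x1, x2])
ref = np.exp(full - full.max()) / np.exp(full - full.max()).sum()
print(np.allclose(fx / lx, ref))  # True
```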
With these identities in place, we can now derive FlashAttention itself.
2. FlashAttention
Below is the algorithm description of FlashAttention:
We now walk through the algorithm line by line:
- Step 0: Assume the matrices $Q, K, V \in \mathbb{R}^{N \times d}$ reside in HBM (GPU global memory), and let $M$ be the size of the on-chip SRAM (GPU shared memory).
- Step 1: Set the block sizes $B_c=\lceil \frac{M}{4d} \rceil$, $B_r=\min(\lceil \frac{M}{4d} \rceil, d)$.
- Step 2: Initialize the $N \times d$ output matrix $O$ to all zeros.
  Initialize the $N$-dimensional vector $l$ to all zeros; it accumulates the softmax denominator, the running sum of exponentiated scores.
  Initialize the $N$-dimensional vector $m$ to $-\infty$; it stores the running row-wise maximum score.
- Step 3: Split $Q$, $K$, $V$ into blocks using the block sizes from Step 1.
  $Q$ is split by $B_r$ into $Q_1,\dots,Q_{T_r}$, each of size $B_r \times d$, giving $T_r=\lceil \frac{N}{B_r} \rceil$ blocks.
  $K$ and $V$ are split by $B_c$ into $K_1,\dots,K_{T_c}$ and $V_1,\dots,V_{T_c}$, each of size $B_c \times d$, giving $T_c=\lceil \frac{N}{B_c} \rceil$ blocks.
- Step 4: Split $O$, $l$, $m$ by $B_r$.
  $O$ (a matrix) is split into $O_1,\dots,O_{T_r}$, each of size $B_r \times d$;
  $l$ (a vector) is split into $l_1,\dots,l_{T_r}$, each of size $B_r$;
  $m$ (a vector) is split into $m_1,\dots,m_{T_r}$, each of size $B_r$.
- Step 5: Outer loop over $1 \le j \le T_c$, i.e. iterate over the Key/Value blocks.
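To make the block sizes concrete, here is a small sketch with hypothetical values for $M$, $d$, and $N$ (these numbers are illustrative, not from the paper):

```python
import math

M, d, N = 1024, 64, 128   # hypothetical: SRAM size (in elements), head dim, sequence length

Bc = math.ceil(M / (4 * d))        # B_c = ceil(M / 4d)
Br = min(math.ceil(M / (4 * d)), d)  # B_r = min(ceil(M / 4d), d)
Tr = math.ceil(N / Br)             # number of Q / O / l / m blocks
Tc = math.ceil(N / Bc)             # number of K / V blocks
print(Bc, Br, Tr, Tc)  # 4 4 32 32
```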
- Step 6: Load $K_j, V_j$ from HBM (global memory) into on-chip SRAM (shared memory). Because of how the block sizes were chosen, at least 50% of SRAM is still free at this point (reserved for $Q$ and $O$).
- Step 7: Inner loop over $1 \le i \le T_r$, i.e. iterate over the Query blocks.
- Step 8: Load $Q_i, O_i, l_i, m_i$ from HBM into on-chip SRAM.
- Step 9: Compute $S_{ij}=Q_iK_j^T \in \mathbb{R}^{B_r \times B_c}$.
- Step 10: From $S_{ij}$, compute $m_{ij}$, $P_{ij}$, $l_{ij}$:

$$
m_{ij}=\mathrm{rowmax}(S_{ij}) \in \mathbb{R}^{B_r} \\
P_{ij}=\exp(S_{ij} - m_{ij}) \in \mathbb{R}^{B_r \times B_c} \\
l_{ij}=\mathrm{rowsum}(P_{ij}) \in \mathbb{R}^{B_r}
$$

- Step 11: Update the running statistics:

$$
m_i^{new}=\max(m_i, m_{ij}) \\
l_i^{new}=e^{m_i-m_i^{new}}l_i + e^{m_{ij}-m_i^{new}}l_{ij}
$$
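Steps 10 and 11 amount to an online softmax: a single pass over the score blocks keeps $(m_i, l_i)$ exact at every point. A sketch of the running update for one row, with arbitrary example scores:

```python
import numpy as np

scores = np.array([0.3, 2.1, -0.7, 1.4, 0.9, 3.2])  # one row of S, arbitrary values
blocks = np.split(scores, 3)                         # T_c = 3 blocks of size B_c = 2

m_i, l_i = -np.inf, 0.0                              # step 2 initial values
for S_ij in blocks:
    m_ij = S_ij.max()                                # rowmax(S_ij)
    P_ij = np.exp(S_ij - m_ij)                       # exp(S_ij - m_ij)
    l_ij = P_ij.sum()                                # rowsum(P_ij)
    m_new = max(m_i, m_ij)                           # step 11
    l_i = np.exp(m_i - m_new) * l_i + np.exp(m_ij - m_new) * l_ij
    m_i = m_new

# after the pass, (m_i, l_i) match the one-shot statistics
print(np.isclose(m_i, scores.max()))                           # True
print(np.isclose(l_i, np.exp(scores - scores.max()).sum()))    # True
```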
- Step 12: Write

$$
O_i \gets \mathrm{diag}(l_i^{new})^{-1}\left(\mathrm{diag}(l_i)\,e^{m_i-m_i^{new}}O_i + e^{m_{ij}-m_i^{new}}P_{ij}V_j\right)
$$
The steps so far are easy to follow; this one is the hardest to understand, so let us derive it (setting $V_j$ aside for the moment):
When only the first block has been processed, the softmax output of the first block is:

$$
O_i \gets \mathrm{softmax}(x^{(1)})= \frac{f(x^{(1)})}{l(x^{(1)})}
$$

so that $O_i=[\frac{f(x^{(1)})}{l(x^{(1)})}, 0, 0, 0, \dots]$ (keep in mind that $O_i$ here is a vector, and the entries belonging to the remaining blocks are all 0).
After the first and second blocks, the softmax output of the first block becomes:

$$
\mathrm{softmax}(x^{(1)})=\frac{e^{m(x^{(1)}) - m(x)}f(x^{(1)})}{e^{m(x^{(1)}) - m(x)}l(x^{(1)})+ e^{m(x^{(2)}) - m(x)}l(x^{(2)})} \\
\to \mathrm{softmax}(x^{(1)})=\frac{e^{m(x^{(1)}) - m(x)}\,O_i\,l(x^{(1)})}{e^{m(x^{(1)}) - m(x)}l(x^{(1)})+ e^{m(x^{(2)}) - m(x)}l(x^{(2)})} \\
\to \mathrm{softmax}(x^{(1)})=\frac{e^{m(x^{(1)}) - m(x)}\,O_i\,l_i}{l_i^{new}}
$$

where the second line substitutes $f(x^{(1)}) = O_i\,l(x^{(1)})$ from the previous state, and the last line writes $l_i = l(x^{(1)})$.
The softmax output of the second block:

$$
\mathrm{softmax}(x^{(2)})=\frac{e^{m(x^{(2)}) - m(x)}f(x^{(2)})}{e^{m(x^{(1)}) - m(x)}l(x^{(1)})+ e^{m(x^{(2)}) - m(x)}l(x^{(2)})} \\
\to \mathrm{softmax}(x^{(2)})=\frac{e^{m(x^{(2)}) - m(x)}f(x^{(2)})}{l_i^{new}}
$$
So the updated $O_i$ should be:

$$
O_i^{new}=\left[\frac{e^{m(x^{(1)}) - m(x)}\,O_i\,l_i}{l_i^{new}},\ \frac{e^{m(x^{(2)}) - m(x)}f(x^{(2)})}{l_i^{new}},\ 0, 0, 0, \dots\right]
$$
This is the recurrence relating successive blocks' softmax results. With it, we can unpack the formula in Step 12:
1. $\mathrm{diag}(l_i^{new})^{-1}$ supplies the denominator $l_i^{new}$.
2. $\mathrm{diag}(l_i)\,e^{m_i-m_i^{new}}O_i$ is exactly the numerator of the $\mathrm{softmax}(x^{(1)})=\frac{e^{m(x^{(1)}) - m(x)}\,O_i\,l_i}{l_i^{new}}$ derived above.
3. And $e^{m_{ij}-m_i^{new}}P_{ij}$ is the numerator of the $\mathrm{softmax}(x^{(2)})=\frac{e^{m(x^{(2)}) - m(x)}f(x^{(2)})}{l_i^{new}}$ derived above.
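Bringing $V_j$ back in, the Step 12 recurrence can be checked numerically for a single query row: running it block by block reproduces the exact softmax-weighted output. A minimal sketch with random toy data:

```python
import numpy as np

np.random.seed(0)
N, d, Bc = 6, 4, 2
scores = np.random.randn(N)        # one row of Q K^T
V = np.random.randn(N, d)

m_i, l_i = -np.inf, 0.0
O_i = np.zeros(d)
for j in range(0, N, Bc):
    S_ij, V_j = scores[j:j+Bc], V[j:j+Bc]
    m_ij = S_ij.max()                                # step 10
    P_ij = np.exp(S_ij - m_ij)
    l_ij = P_ij.sum()
    m_new = max(m_i, m_ij)                           # step 11
    l_new = np.exp(m_i - m_new) * l_i + np.exp(m_ij - m_new) * l_ij
    # step 12: rescale the old accumulator, add the new block's contribution
    O_i = (l_i * np.exp(m_i - m_new) * O_i + np.exp(m_ij - m_new) * (P_ij @ V_j)) / l_new
    m_i, l_i = m_new, l_new

# reference: one-shot softmax over all scores, then weight V
p = np.exp(scores - scores.max()) / np.exp(scores - scores.max()).sum()
print(np.allclose(O_i, p @ V))  # True
```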
- Step 13: Write $l_i \gets l_i^{new}$, $m_i \gets m_i^{new}$ to HBM.
- Step 14: end for (inner loop)
- Step 15: end for (outer loop)
- Step 16: Return $O$
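The whole algorithm can be sketched end to end in NumPy. This keeps the double-loop structure of Steps 5-16 but, of course, ignores the actual HBM/SRAM placement that makes the real kernel fast; the softmax scaling factor $1/\sqrt{d}$ is also omitted, as throughout this derivation:

```python
import numpy as np

def flash_attention(Q, K, V, Br=2, Bc=2):
    # Blockwise attention following steps 0-16 (no 1/sqrt(d) scaling, no masking).
    N, d = Q.shape
    O = np.zeros((N, d))          # step 2
    l = np.zeros(N)
    m = np.full(N, -np.inf)
    for j in range(0, N, Bc):                       # step 5: loop over K/V blocks
        Kj, Vj = K[j:j+Bc], V[j:j+Bc]               # step 6
        for i in range(0, N, Br):                   # step 7: loop over Q blocks
            Qi, Oi = Q[i:i+Br], O[i:i+Br]           # step 8
            li, mi = l[i:i+Br], m[i:i+Br]
            Sij = Qi @ Kj.T                         # step 9
            mij = Sij.max(axis=1)                   # step 10
            Pij = np.exp(Sij - mij[:, None])
            lij = Pij.sum(axis=1)
            m_new = np.maximum(mi, mij)             # step 11
            l_new = np.exp(mi - m_new) * li + np.exp(mij - m_new) * lij
            # step 12: rescale old output, add this block's contribution, renormalize
            O[i:i+Br] = ((li * np.exp(mi - m_new))[:, None] * Oi
                         + np.exp(mij - m_new)[:, None] * (Pij @ Vj)) / l_new[:, None]
            l[i:i+Br], m[i:i+Br] = l_new, m_new     # step 13
    return O                                        # step 16

np.random.seed(0)
Q, K, V = (np.random.randn(6, 4) for _ in range(3))
S = Q @ K.T
P = np.exp(S - S.max(axis=1, keepdims=True))
ref = (P / P.sum(axis=1, keepdims=True)) @ V        # naive attention, for comparison
print(np.allclose(flash_attention(Q, K, V), ref))   # True
```

The inner loop over $i$ sits inside the outer loop over $j$, so each output block $O_i$ is rewritten $T_c$ times; the running $(m_i, l_i)$ statistics are what keep every intermediate write consistent with the blocks seen so far.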