当前位置: 首页 > news >正文

FlashAttention 公式推导

本文目前只介绍关于FlashAttention的公式推导,相关背景可参考:

  • paper
  • blog

一、分块下的softmax如何计算

对于向量 x ∈ R B x \in R^B xRB
x = [ x 1 , x 2 , . . . , x B ] m ( x ) : = max ⁡ i x i = m a x ( x 1 , x 2 , . . . , x B ) f ( x ) : = [ e x 1 − m ( x ) , e x 2 − m ( x ) , . . . , e x B − m ( x ) ] l ( x ) : = ∑ i B f ( x ) i = e x 1 − m ( x ) + e x 2 − m ( x ) + . . . + e x B − m ( x ) s o f t m a x ( x ) : = f ( x ) l ( x ) = [ e x 1 − m ( x ) l ( x ) , e x 2 − m ( x ) l ( x ) , . . . , e x B − m ( x ) l ( x ) ] x=[x_1,x_2,...,x_B] \\ m(x) := \max_i x_i = max(x_1,x_2,...,x_B) \\ f(x) := [e^{x_1-m(x)}, e^{x_2-m(x)}, ..., e^{x_B-m(x)}] \\ l(x) := \sum_{i}^{B}f(x)_i=e^{x_1-m(x)} + e^{x_2-m(x)} + ... + e^{x_B-m(x)} \\ softmax(x) := \frac{f(x)}{l(x)}=[\frac{e^{x_1-m(x)}}{l(x)}, \frac{e^{x_2-m(x)}}{l(x)}, ..., \frac{e^{x_B-m(x)}}{l(x)}] x=[x1,x2,...,xB]m(x):=imaxxi=max(x1,x2,...,xB)f(x):=[ex1m(x),ex2m(x),...,exBm(x)]l(x):=iBf(x)i=ex1m(x)+ex2m(x)+...+exBm(x)softmax(x):=l(x)f(x)=[l(x)ex1m(x),l(x)ex2m(x),...,l(x)exBm(x)]
则对于向量 x ( 1 ) , x ( 2 ) ∈ R B x^{(1)}, x^{(2)} \in R^B x(1),x(2)RB, x = [ x ( 1 ) , x ( 2 ) ] ∈ R 2 B x=[x^{(1)},x^{(2)}] \in R^{2B} x=[x(1),x(2)]R2B
x ( 1 ) = [ x 1 ( 1 ) , x 2 ( 1 ) , . . . , x B ( 1 ) ] , x ( 2 ) = [ x 1 ( 2 ) , x 2 ( 2 ) , . . . , x B ( 2 ) ] m ( x ) = m ( [ x ( 1 ) , x ( 2 ) ] ) = m a x ( m ( x ( 1 ) ) , m ( x ( 2 ) ) ) f ( x ) = [ e m ( x ( 1 ) ) − m ( x ) f ( x ( 1 ) ) , e m ( x ( 2 ) ) − m ( x ) f ( x ( 2 ) ) ] l ( x ) = l ( [ x ( 1 ) , x ( 2 ) ] ) = e m ( x ( 1 ) ) − m ( x ) l ( x ( 1 ) ) + e m ( x ( 2 ) ) − m ( x ) l ( x ( 2 ) ) s o f t m a x ( x ) = f ( x ) l ( x ) x^{(1)} = [x_1^{(1)},x_2^{(1)},...,x_B^{(1)}], x^{(2)}=[x_1^{(2)},x_2^{(2)},...,x_B^{(2)}]\\ m(x)=m([x^{(1)}, x^{(2)}])=max(m(x^{(1)}), m(x^{(2)})) \\ f(x)=[e^{m(x^{(1)}) - m(x)}f(x^{(1)}), e^{m(x^{(2)}) - m(x)}f(x^{(2)})]\\ l(x)=l([x^{(1)}, x^{(2)}])=e^{m(x^{(1)}) - m(x)}l(x^{(1)})+ e^{m(x^{(2)}) - m(x)}l(x^{(2)})\\ softmax(x)=\frac{f(x)}{l(x)} x(1)=[x1(1),x2(1),...,xB(1)],x(2)=[x1(2),x2(2),...,xB(2)]m(x)=m([x(1),x(2)])=max(m(x(1)),m(x(2)))f(x)=[em(x(1))m(x)f(x(1)),em(x(2))m(x)f(x(2))]l(x)=l([x(1),x(2)])=em(x(1))m(x)l(x(1))+em(x(2))m(x)l(x(2))softmax(x)=l(x)f(x)
下面来推导 f ( x ) f(x) f(x)
f ( x ) = [ e x 1 ( 1 ) − m ( x ) , e x 2 ( 1 ) − m ( x ) , . . . , e x B ( 1 ) − m ( x ) , e x 1 ( 2 ) − m ( x ) , e x 2 ( 2 ) − m ( x ) , . . . , e x B ( 2 ) − m ( x ) ] → f ( x ) = [ e x 1 ( 1 ) − m ( x ( 1 ) ) + m ( x ( 1 ) ) − m ( x ) , e x 2 ( 1 ) − m ( x ( 1 ) ) + m ( x ( 1 ) ) − m ( x ) , . . . ] → f ( x ) = [ e x 1 ( 1 ) − m ( x ( 1 ) ) ∗ e m ( x ( 1 ) ) − m ( x ) , e x 2 ( 1 ) − m ( x ( 1 ) ) ∗ e m ( x ( 1 ) ) − m ( x ) , . . . ] → f ( x ) = [ e m ( x ( 1 ) ) − m ( x ) f ( x ( 1 ) ) , e m ( x ( 2 ) ) − m ( x ) f ( x ( 2 ) ) ] f(x)=[e^{x_1^{(1)} - m(x)},e^{x_2^{(1)} - m(x)},...,e^{x_B^{(1)} - m(x)}, e^{x_1^{(2)} - m(x)},e^{x_2^{(2)} - m(x)},...,e^{x_B^{(2)} - m(x)}]\\ \to f(x)=[e^{x_1^{(1)} - m(x^{(1)}) + m(x^{(1)}) - m(x)}, e^{x_2^{(1)} -m(x^{(1)}) + m(x^{(1)}) - m(x)}, ...]\\ \to f(x)=[e^{x_1^{(1)} - m(x^{(1)})} * e^{m(x^{(1)}) - m(x)}, e^{x_2^{(1)} -m(x^{(1)})} * e^{m(x^{(1)}) - m(x)}, ...]\\ \to f(x)=[e^{m(x^{(1)}) - m(x)}f(x^{(1)}), e^{m(x^{(2)}) - m(x)}f(x^{(2)})] f(x)=[ex1(1)m(x),ex2(1)m(x),...,exB(1)m(x),ex1(2)m(x),ex2(2)m(x),...,exB(2)m(x)]f(x)=[ex1(1)m(x(1))+m(x(1))m(x),ex2(1)m(x(1))+m(x(1))m(x),...]f(x)=[ex1(1)m(x(1))em(x(1))m(x),ex2(1)m(x(1))em(x(1))m(x),...]f(x)=[em(x(1))m(x)f(x(1)),em(x(2))m(x)f(x(2))]
下面来推导 l ( x ) l(x) l(x)
l ( x ) = ∑ i B e m ( x ( 1 ) ) − m ( x ) f ( x ( 1 ) ) + ∑ i B e m ( x ( 2 ) ) − m ( x ) f ( x ( 2 ) ) → l ( x ) = e m ( x ( 1 ) ) − m ( x ) l ( x ( 1 ) ) + e m ( x ( 2 ) ) − m ( x ) l ( x ( 2 ) ) l(x)=\sum_{i}^{B}e^{m(x^{(1)}) - m(x)}f(x^{(1)}) + \sum_{i}^{B}e^{m(x^{(2)}) - m(x)}f(x^{(2)})\\ \to l(x)=e^{m(x^{(1)}) - m(x)}l(x^{(1)})+ e^{m(x^{(2)}) - m(x)}l(x^{(2)}) l(x)=iBem(x(1))m(x)f(x(1))+iBem(x(2))m(x)f(x(2))l(x)=em(x(1))m(x)l(x(1))+em(x(2))m(x)l(x(2))
s o f t m a x ( x ) = f ( x ) l ( x ) = [ e m ( x ( 1 ) ) − m ( x ) f ( x ( 1 ) ) e m ( x ( 1 ) ) − m ( x ) l ( x ( 1 ) ) + e m ( x ( 2 ) ) − m ( x ) l ( x ( 2 ) ) , e m ( x ( 2 ) ) − m ( x ) f ( x ( 2 ) ) e m ( x ( 1 ) ) − m ( x ) l ( x ( 1 ) ) + e m ( x ( 2 ) ) − m ( x ) l ( x ( 2 ) ) ] softmax(x)=\frac{f(x)}{l(x)}=[\frac{e^{m(x^{(1)}) - m(x)}f(x^{(1)})}{e^{m(x^{(1)}) - m(x)}l(x^{(1)})+ e^{m(x^{(2)}) - m(x)}l(x^{(2)})}, \frac{e^{m(x^{(2)}) - m(x)}f(x^{(2)})}{e^{m(x^{(1)}) - m(x)}l(x^{(1)})+ e^{m(x^{(2)}) - m(x)}l(x^{(2)})}] softmax(x)=l(x)f(x)=[em(x(1))m(x)l(x(1))+em(x(2))m(x)l(x(2))em(x(1))m(x)f(x(1)),em(x(1))m(x)l(x(1))+em(x(2))m(x)l(x(2))em(x(2))m(x)f(x(2))]
有了当前的公式基础,我们可以开始FlashAttention的公式推导了

二、FlashAttention

下面是FlashAttention的算法描述:
在这里插入图片描述
下面我们逐行解释算法:

  • 0、假设矩阵 Q 、 K 、 V ∈ R N × d Q、K、V \in R^{N \times d} QKVRN×d位于HBM(GPU global memory),on-chip SRAM(GPU share memory)的内存大小为 M。
  • 1、设置块大小为 B c = ⌈ M 4 d ⌉ , B r = m i n ( ⌈ M 4 d ⌉ , d ) B_c=\lceil \frac{M}{4d} \rceil,B_r=min(\lceil \frac{M}{4d} \rceil, d) Bc=4dMBr=min(⌈4dM,d)
  • 2、初始化 [ N × d ] [N \times d] [N×d] 输出矩阵 O 全为0
          初始化 N N N 维向量 l l l 全为0。存储 softmax 的累积分母——指数分数的总和
          初始化 N N N 维向量 m m m 全为 − ∞ -\infty 。存储按行最大分数
  • 3、使用步骤1中的块大小将 Q、K、V 分块。
          Q 按 B r B_r Br 分块 Q 1 , . . . , Q T r Q_1,...,Q_{T_r} Q1,...,QTr,每个块的维度是 [ B r × d ] [B_r \times d] [Br×d], Q的块数为 T r = ⌈ N B r ⌉ T_r=\lceil \frac{N}{B_r} \rceil Tr=BrN
          K、V 按 B c B_c Bc 分块为 K 1 , . . . , K T c K_1,...,K_{T_c} K1,...,KTc V 1 , . . . , V T c V_1,...,V_{T_c} V1,...,VTc,每个块的维度是 [ B c × d ] [B_c \times d] [Bc×d],K、V 的块数为 T c = ⌈ N B c ⌉ T_c=\lceil \frac{N}{B_c} \rceil Tc=BcN
  • 4、将O、l、m 按 B r B_r Br 分块。
          O(矩阵) 分成 O 1 , . . . , O T r O_1,...,O_{T_r} O1,...,OTr,每个块大小为 [ B r × d ] [B_r \times d] [Br×d]
          l(向量) 分成 l 1 , . . . , l T r l_1,...,l_{T_r} l1,...,lTr,每个块大小为 B r B_r Br
          m(向量) 分成 m 1 , . . . , m T r m_1,...,m_{T_r} m1,...,mTr,每个块大小为 B r B_r Br
  • 5、outloop 遍历 $for 1 <= j <= T_c $,即遍历 Key/Value 向量
  • 6、从 HBM(global memory) 加载 K i , V i K_i,V_i Ki,Vi 到 on-chip SRAM(share memory).由于我们构建块大小的方式,此时 SRAM 仍有至少 50%未被占用(用于 Q 和 O)。
  • 7、innerloop 遍历 $for 1<= i <= T_r $,即对 Query 向量进行循环
  • 8、从 HBM 加载 Q i , O i , l i , m i Q_i,O_i,l_i,m_i Qi,Oi,li,mi 到 on-chip SRAM。
  • 9、计算 S i j = Q i K j T ∈ R B r × B c S_{ij}=Q_iK_j^T \in R^{B_r \times B_c} Sij=QiKjTRBr×Bc
  • 10、使用上一步的 S i j S_{ij} Sij 计算 m i j , l i j , P i j m_{ij},l_{ij},P_{ij} mij,lij,Pij
    m i j = r o w m a x ( S i j ) ∈ R B r P i j = e x p ( S i j − m i j ) ∈ R B r × B c l i j = r o w s u m ( p i j ) ∈ R B r m_{ij}=rowmax(S_{ij}) \in R^{B_r} \\ P_{ij}=exp(S_{ij} - m_{ij}) \in R^{B_r \times B_c}\\ l_{ij}=rowsum(p_{ij}) \in R^{B_r} mij=rowmax(Sij)RBrPij=exp(Sijmij)RBr×Bclij=rowsum(pij)RBr
  • 11、计算 m i n e w = m a x ( m i , m i j ) l i n e w = e m i − m i n e w l i + e m i j − m i n e w l i j m_i^{new}=max(m_i, m_{ij}) \\ l_i^{new}=e^{m_i-m_i^{new}}l_i + e^{m_{ij-m_i^{new}}}l_{ij} minew=max(mi,mij)linew=emiminewli+emijminewlij
  • 12、
    W r i t e O i ← d i a g ( l i n e w ) − 1 ( d i a g ( l i ) e m i − m i n e w O i + e m i j − m i n e w P i j V j ) Write \ O_i \gets diag(l_i^{new})^{-1}(diag(l_i)e^{m_i-m_i^{new}}O_i + e^{m_{ij}-m_i^{new}}P_{ij}V_j) Write Oidiag(linew)1(diag(li)emiminewOi+emijminewPijVj)
    上面的过程都很好理解,这里是最难理解的一步,我们来推导一下(我们不考虑 V i V_i Vi):
    只第一个块时,第一个块的softmax输出:
    O i ← s o f t m a x ( x ( 1 ) ) = f ( x ( 1 ) ) l ( x ( 1 ) ) 从而 O i = [ f ( x ( 1 ) ) l ( x ( 1 ) ) , 0 , 0 , 0 , . . . ] ( 请记住这里 O i 是向量,且其余块的值都为 0 ) O_i\gets softmax(x^{(1)})= \frac{f(x^{(1)})}{l(x^{(1)})} \\ 从而O_i=[\frac{f(x^{(1)})}{l(x^{(1)})}, 0, 0, 0,...] \\ (请记住这里 O_i 是向量,且其余块的值都为0) Oisoftmax(x(1))=l(x(1))f(x(1))从而Oi=[l(x(1))f(x(1)),0,0,0,...](请记住这里Oi是向量,且其余块的值都为0)
    第一、二个块时,第一个块的softmax输出:
    s o f t m a x ( x ( 1 ) ) = e m ( x ( 1 ) ) − m ( x ) f ( x ( 1 ) ) e m ( x ( 1 ) ) − m ( x ) l ( x ( 1 ) ) + e m ( x ( 2 ) ) − m ( x ) l ( x ( 2 ) ) → s o f t m a x ( x ( 1 ) ) = e m ( x ( 1 ) ) − m ( x ) O i l ( x ( 1 ) ) e m ( x ( 1 ) ) − m ( x ) l ( x ( 1 ) ) + e m ( x ( 2 ) ) − m ( x ) l ( x ( 2 ) ) → s o f t m a x ( x ( 1 ) ) = e m ( x ( 1 ) ) − m ( x ) O i l ( i ) l i n e w softmax(x^{(1)})=\frac{e^{m(x^{(1)}) - m(x)}f(x^{(1)})}{e^{m(x^{(1)}) - m(x)}l(x^{(1)})+ e^{m(x^{(2)}) - m(x)}l(x^{(2)})}\\ \to softmax(x^{(1)})=\frac{e^{m(x^{(1)}) - m(x)}O_il(x^{(1)})}{e^{m(x^{(1)}) - m(x)}l(x^{(1)})+ e^{m(x^{(2)}) - m(x)}l(x^{(2)})} \\ \to softmax(x^{(1)})=\frac{e^{m(x^{(1)}) - m(x)}O_il(_i)}{l_i^{new}} softmax(x(1))=em(x(1))m(x)l(x(1))+em(x(2))m(x)l(x(2))em(x(1))m(x)f(x(1))softmax(x(1))=em(x(1))m(x)l(x(1))+em(x(2))m(x)l(x(2))em(x(1))m(x)Oil(x(1))softmax(x(1))=linewem(x(1))m(x)Oil(i)
    第二个块的softmax输出:
    s o f t m a x ( x ( 2 ) ) = e m ( x ( 2 ) ) − m ( x ) f ( x ( 2 ) ) e m ( x ( 1 ) ) − m ( x ) l ( x ( 1 ) ) + e m ( x ( 2 ) ) − m ( x ) l ( x ( 2 ) ) → s o f t m a x ( x ( 2 ) ) = e m ( x ( 2 ) ) − m ( x ) f ( x ( 2 ) ) l i n e w softmax(x^{(2)})=\frac{e^{m(x^{(2)}) - m(x)}f(x^{(2)})}{e^{m(x^{(1)}) - m(x)}l(x^{(1)})+ e^{m(x^{(2)}) - m(x)}l(x^{(2)})}\\ \to softmax(x^{(2)})=\frac{e^{m(x^{(2)}) - m(x)}f(x^{(2)})}{l_i^{new}} softmax(x(2))=em(x(1))m(x)l(x(1))+em(x(2))m(x)l(x(2))em(x(2))m(x)f(x(2))softmax(x(2))=linewem(x(2))m(x)f(x(2))
    则理论上 O i O_i Oi就应该是:
    O i n e w = [ e m ( x ( 1 ) ) − m ( x ) O i l ( i ) l i n e w , e m ( x ( 2 ) ) − m ( x ) f ( x ( 2 ) ) l i n e w , 0 , 0 , 0 , . . . ] O_{i_{new}}=[\frac{e^{m(x^{(1)}) - m(x)}O_il(_i)}{l_i^{new}}, \frac{e^{m(x^{(2)}) - m(x)}f(x^{(2)})}{l_i^{new}},0,0, 0,...] Oinew=[linewem(x(1))m(x)Oil(i),linewem(x(2))m(x)f(x(2)),0,0,0,...]

这就是每个块softmax结果的递推过程了。有了这些,我们就可以拆解一下步骤12中的公式:
1、 d i a g ( l i n e w ) − 1 diag(l_i^{new})^{-1} diag(linew)1 即将 l i n e w l_i^{new} linew 作为分母
2、 d i a g ( l i ) e m i − m i n e w O i diag(l_i)e^{m_i-m_i^{new}}O_i diag(li)emiminewOi 不难发现是我们上面推导的 s o f t m a x ( x ( 1 ) ) = e m ( x ( 1 ) ) − m ( x ) O i l ( i ) l i n e w softmax(x^{(1)})=\frac{e^{m(x^{(1)}) - m(x)}O_il(_i)}{l_i^{new}} softmax(x(1))=linewem(x(1))m(x)Oil(i) 的分子部分
3、而 e m i j − m i n e w P i j e^{m_{ij}-m_i^{new}}P_{ij} emijminewPij 就是上面 s o f t m a x ( x ( 2 ) ) = e m ( x ( 2 ) ) − m ( x ) f ( x ( 2 ) ) l i n e w softmax(x^{(2)})=\frac{e^{m(x^{(2)}) - m(x)}f(x^{(2)})}{l_i^{new}} softmax(x(2))=linewem(x(2))m(x)f(x(2)) 的分子部分

  • 13、
    W r i t e l i ← l i n e w , m i ← m i n e w t o H B M Write \ l_i \gets l_i^{new}, m_i \gets m_i^{new} \ \ to \ \ \ HBM Write lilinew,miminew  to   HBM
  • 14、 end for innerloop
  • 15、end for outloop
  • 16、Return O
http://www.xdnf.cn/news/922339.html

相关文章:

  • [AI绘画]sd学习记录(二)文生图参数进阶
  • Rapidio门铃消息FIFO溢出机制
  • TongWeb7.0动态密钥说明
  • 实战:子组件获取父组件订单信息
  • 【学习笔记】如何给软件加数字签名
  • 在 Windows 11 或 10 上将 Git 升级到最新版本的方法
  • 【Linux】LInux下第一个程序:进度条
  • 十一、【ESP32开发全栈指南: TCP通信服务端】
  • 1-3 Linux-虚拟机(2025.6.7学习篇- mac版本)
  • Sentry 接口返回 Status Code 429 Too Many Requests
  • 【优选算法】C++滑动窗口
  • 在ubuntu等linux系统上申请https证书
  • Redis内存淘汰策略
  • redis集群
  • [最全总结]城市灾害应急管理系统
  • Linux虚拟化技术:从KVM到容器的轻量化革命
  • Nodejs工程化实践:构建高性能前后端交互系统
  • sqlsugar WhereIF条件的大于等于和等于查出来的坑
  • WSL文件如何上传到GitHub
  • python版若依框架开发:后端开发规范
  • 快捷键的记录
  • UOS无法安装deb软件包
  • [论文阅读] 人工智能 | 搜索增强LLMs的用户偏好与性能分析
  • AcWing--数据结构1
  • stm32—ADC和DAC
  • 《JavaAI:稳定、高效、跨平台的AI编程工具优势解析》
  • Linux下的fuser用法简析
  • 文件(保存)通讯录
  • 长跑赛接力赛模式
  • C++ -- 多态