当前位置: 首页 > news >正文

【大模型LLM学习】Flash-Attention的学习记录

【大模型LLM学习】Flash-Attention的学习记录

  • 0. 前言
  • 1. flash-attention原理简述
  • 2. 从softmax到online softmax
    • 2.1 safe-softmax
    • 2.2 3-pass safe softmax
    • 2.3 Online softmax
    • 2.4 Flash-attention
    • 2.5 Flash-attention tiling

0. 前言

  Flash Attention可以节约模型训练和推理时间,很多模型可以通过config参数来选择attention是标准的attention实现还是flash-attention方式。在这里记录一下flash attention的学习过程,发现了一位博主以及参考的资料特别好:

  • zhihu一位做高性能计算的博主博文
  • 华盛顿大学的课程note

1. flash-attention原理简述

a t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V attention(Q,K,V)=softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V attention(Q,K,V)=softmax(dk QKT)V
  标准的attention操作的时间卡点不是在运算上,而是卡在数据读写上。SRAM的读写速度快,但是存储空间有限,无法一次存下来所有的中间计算结果,一次attention计算存在SRAM<->HBM的多次读写操作。
在这里插入图片描述
  与标准的attention操作比较,flash-attention通过减少数据在HBM和SRAM间的读写操作,来节约时间(甚至backward时还进行了重新计算,重新计算的速度也比把数据从HBM读取到SRAM要快)。
https://huggingface.co/docs/text-generation-inference/conceptual/flash_attention

2. 从softmax到online softmax

  直接看flash-attention的论文比较难看明白,发现华盛顿大学的那份note写得特别清晰,跟着它从softmax看到flash-attention会比较容易。

2.1 safe-softmax

  首先是safe的softmax计算方式。原始的softmax,对于N个数:
s o f t m a x ( { x 1 , . . . , x N } ) = { e x i ∑ j = 1 N e x j } i = 1 N softmax(\{x_1,...,x_N\})=\left\{\frac{e^{x_i}}{\sum_{j=1}^{N}e^{x_j}}\right\}_{i=1}^{N} softmax({x1,...,xN})={j=1Nexjexi}i=1N
  对于FP16,最大能表示的数据为65536,当 x > = 11 x>=11 x>=11时, e x e^x ex就会超过FP16的最大表示范围影响结果的正确性。为了避免这个问题,SafeSoftmax 通过减去输入向量中的最大值来调整输入,使得最大的指数项变为 e 0 = 1 e^0=1 e0=1从而防止了上溢的发生。同时,由于所有的指数项都除以同一个数,它们的比例关系不会改变,因此也不会影响最终的概率分布。
e x i ∑ j = 1 N e x j = e x i − m ∑ j = 1 N e x j − m , m = m a x { x j } j = 1 N \frac{e^{x_i}}{\sum_{j=1}{N}e^{x_j}}=\frac{e^{x_i-m}}{\sum_{j=1}{N}e^{x_j-m}}, \quad m=max\left\{x_j\right\}_{j=1}^{N} j=1Nexjexi=j=1Nexjmexim,m=max{xj}j=1N

2.2 3-pass safe softmax

  • 对于一个行向量 { x i } i = 1 N \{x_i\}_{i=1}^N {xi}i=1N,最直白的softmax计算方式是直接for循环

在这里插入图片描述
  这个算法计算softmax需要执行3次从1->N的循环,在attention中, { x i } \{x_i\} {xi} Q K T QK^T QKT的结果,但是如果SRAM里面存不下这个大的矩阵,上面的计算过程,就需要从HBM里面加载3次 { x i } \{x_i\} {xi},时间花在了数据读写上。

2.3 Online softmax

  如果能把上面(7)(8)(9)这3个式子的计算放一个for循环,就只需要一次load数据。但是 m N m_N mN是全局最大值,计算 m N m_N mN就已经需要一次遍历了。
  Online softmax算法把(7)(8)进行了合并,把3次遍历缩减为2个。它提出计算 d i ′ = ∑ j = 1 i e x j − m i d_i^{\prime}=\sum_{j=1}^{i}e^{x_j-m_i} di=j=1iexjmi来代替计算 d i d_i di,当算到最后 i = N i=N i=N时会发现, d N = d N ′ d_N=d_N^{\prime} dN=dN。具体的,迭代计算 d i ′ d_i^{\prime} di的方式为:
d i ′ = ∑ j = 1 i e x j − m i = ( ∑ j = 1 i − 1 e x j − m i ) + e x i − m i = ( ∑ j = 1 i − 1 e x j − m i − 1 ) e m i − 1 − m i + e x i − m i = d i − 1 ′ e m i − 1 − m i + e x i − m i \begin{aligned} d_i^{\prime} &= \sum_{j=1}^{i} e^{x_j - m_i} \\ &= \left( \sum_{j=1}^{i-1} e^{x_j - m_i} \right) + e^{x_i - m_i} \\ &= \left( \sum_{j=1}^{i-1} e^{x_j - m_{i-1}} \right) e^{m_{i-1} - m_i} + e^{x_i - m_i} \\ &= d_{i-1}^{\prime} e^{m_{i-1} - m_i} + e^{x_i - m_i} \end{aligned} di=j=1iexjmi=(j=1i1exjmi)+eximi=(j=1i1exjmi1)emi1mi+eximi=di1emi1mi+eximi

  所以就可以用迭代的方式,在找最大值 m N m_N mN的时候,同时来计算 d i ′ d_i^{\prime} di,把(7)和(8)一起计算,这样只需要加载两次 x i x_i xi

在这里插入图片描述

2.4 Flash-attention

  上面的online softmax仍然需要2个for循环,加载2次 x i x_i xi来完成softmax的计算。完成softmax的计算,没法更进一步地压缩到1次遍历。但是attention计算的最终目标是获取输出结果,也就是注意力分数与 V V V相乘的结果 O = A × V O=A \times V O=A×V,计算 O O O可以通过一次遍历完成。
在这里插入图片描述
  可以使用类似online softmax把计算 d i d_i di变成计算 d i ′ d_i^{\prime} di的方式,把 o i o_i oi的计算也改成迭代式的,首先把 a i a_i ai带入 o i o_i oi的表达式
o i = ∑ j = 1 i ( e x j − m N d N ′ V [ j , : ] ) o_i=\sum_{j=1}^{i}\left(\frac{e^{x_j-m_{N}}}{d_N^{\prime}}V[j,:]\right) oi=j=1i(dNexjmNV[j,:])

  可以找到一个 o i ′ o_i^{\prime} oi,它不依赖于全局的 d N ′ d_N^{\prime} dN m N m_N mN
o i ′ = ∑ j = 1 i ( e x j − m i d i ′ V [ j , : ] ) o_i^{\prime}=\sum_{j=1}^{i}\left(\frac{e^{x_j-m_{i}}}{d_i^{\prime}}V[j,:]\right) oi=j=1i(diexjmiV[j,:])

  对于 o i ′ o_i^{\prime} oi的计算可以使用迭代的方式,同样的是有 o N = o N ′ o_N=o_N^{\prime} oN=oN
o i ′ = ∑ j = 1 i e x j − m i d i ′ V [ j , : ] = ( ∑ j = 1 i − 1 e x j − m i d i ′ V [ j , : ] ) + e x i − m i d i ′ V [ i , : ] = ( ∑ j = 1 i − 1 e x j − m i − 1 d i − 1 ′ e x j − m i e x j − m i − 1 d i − 1 ′ d i ′ V [ j , : ] ) + e x i − m i d i ′ V [ i , : ] = ( ∑ j = 1 i − 1 e x j − m i − 1 d i − 1 ′ V [ j , : ] ) d i − 1 ′ d i ′ e m i − 1 − m i + e x i − m i d i ′ V [ i , : ] = o i − 1 ′ d i − 1 ′ e m i − 1 − m i d i ′ + e x i − m i d i ′ V [ i , : ] \begin{aligned} o_i' &= \sum_{j=1}^{i} \frac{e^{x_j - m_i}}{d_i'} V[j,:] \\ &= \left( \sum_{j=1}^{i-1} \frac{e^{x_j - m_i}}{d_i'} V[j,:] \right) + \frac{e^{x_i - m_i}}{d_i'} V[i,:] \\ &= \left( \sum_{j=1}^{i-1} \frac{e^{x_j - m_{i-1}}}{d_{i-1}'} \frac{e^{x_j - m_i}}{e^{x_j - m_{i-1}}} \frac{d_{i-1}'}{d_i'} V[j,:] \right) + \frac{e^{x_i - m_i}}{d_i'} V[i,:] \\ &= \left( \sum_{j=1}^{i-1} \frac{e^{x_j - m_{i-1}}}{d_{i-1}'} V[j,:] \right) \frac{d_{i-1}'}{d_i'} e^{m_{i-1} - m_i} + \frac{e^{x_i - m_i}}{d_i'} V[i,:] \\ &= o_{i-1}' \frac{d_{i-1}' e^{m_{i-1} - m_i}}{d_i'} + \frac{e^{x_i - m_i}}{d_i'} V[i,:] \end{aligned} oi=j=1idiexjmiV[j,:]=(j=1i1diexjmiV[j,:])+dieximiV[i,:]=(j=1i1di1exjmi1exjmi1exjmididi1V[j,:])+dieximiV[i,:]=(j=1i1di1exjmi1V[j,:])didi1emi1mi+dieximiV[i,:]=oi1didi1emi1mi+dieximiV[i,:]

  这样计算attention的输出结果可以只进行一次遍历就完成
在这里插入图片描述

2.5 Flash-attention tiling

  上面是每次计算一个元素 [ i ] [i] [i],实际上可以一次读取一个大小为b的块(tile)来计算

在这里插入图片描述在这里插入图片描述

  此外,在flash-attention的paper里面,对 Q Q Q K K K V V V O O O分块,其中 Q Q Q
O O O每块大小为 m i n ( M / 4 d , d ) × d min(M/4d,d) \times d min(M/4d,d)×d K / V K/V K/V的每块大小为 M / 4 d × d M/4d \times d M/4d×d,加起来正好不会超过SRAM的大小M,完整的算法在paper中:
在这里插入图片描述

http://www.xdnf.cn/news/903403.html

相关文章:

  • 阿里140 补环境日志
  • 华为 “一底双长焦” 专利公布,引领移动影像新变革
  • Caliper 负载(Workload)详细解析
  • 【NLP中向量化方式】序号化,亚编码,词袋法等
  • MySQL数据库基础(二)———数据表管理
  • 安卓基础(生成APK)
  • React 第五十六节 Router 中useSubmit的使用详解及注意事项
  • next,react封装axios,http请求
  • ✅ 常用 Java HTTP 客户端汇总及使用示例
  • 【零基础 快速学Java】韩顺平 零基础30天学会Java[学习笔记]
  • HTTP 请求协议简单介绍
  • 2025年SEVC SCI2区,潜力驱动多学习粒子群算法PDML-PSO,深度解析+性能实测
  • MySQL查询语句(续)
  • uniapp Vue2 获取电量的独家方法:绕过官方插件限制
  • Amazon Bedrock 助力 SolveX.AI 构建智能解题 Agent,打造头部教育科技应用
  • 当丰收季遇上超导磁测量:粮食产业的科技新征程
  • 智能手表健康监测系统的PSRAM存储芯片CSS6404LS-LI—高带宽、耐高温、微尺寸的三重突破
  • 微算法科技(NASDAQ:MLGO)基于信任的集成共识和灰狼优化(GWO)算法,搭建高信任水平的区块链网络
  • Guava LoadingCache 使用指南
  • Web前端基础:HTML-CSS
  • D3ctf-web-d3invitation单题wp
  • Q: dify前端使用哪些开发框架?
  • Houdini POP入门学习05 - 物理属性
  • 无头浏览器技术:Python爬虫如何精准模拟搜索点击
  • 每日八股文6.6
  • PowerBI企业运营分析—列互换式中国式报表分析
  • 【应用】Ghost Dance:利用惯性动捕构建虚拟舞伴
  • 单片机内部结构基础知识 FLASH相关解读
  • 数据集-目标检测系列- 口红嘴唇 数据集 lips >> DataBall
  • windows10搭建nfs服务器