当前位置: 首页 > news >正文

Multi-Query Attention:传统自注意力( Self-Attention)优化显存和加速方案

本文导读:Multi-Query Attention(MQA)是 Google Research 2022 年提出的一项轻量化注意力技术,通过“多查询、单键值”的设计,把自注意力层的 KV 缓存从 O(h·n·d) 降到 O(n·d),在不牺牲模型精度的前提下大幅节省显存与带宽。如今 Falcon-40B、ChatGLM2-6B、Llama-3-Instruct 等热门开源模型均默认开启 MQA。本文以“原理 → 数学推导 → 代码实践 → 典型模型 → 优缺点”的路线,系统梳理 MQA 的来龙去脉,并给出 PyTorch / Transformers 的落地示例,帮助你一步上手。

摘要

Multi-Query Attention 通过共享 Key / Value、仅为每个头保留独立 Query,使注意力计算的时间复杂度不变、显存使用与 I/O 成本成倍下降;在 GPT-NeoX-20B 长序列基准中将推理速度提升 30-40%,显存削减约 60%。

1 痛点:多头注意力的 KV 爆炸

多头注意力把隐藏维 d 均分成 h 个头,每个头都要持有一份 KV。在自回归推理阶段,需要把所有历史 token 的 KV 保存在 GPU 显存中:

\text{Memory}\!=\!O(h\!\times\!n\!\times\!d_{\text{head}})

h = 32、n = 8 K、d=4 096 时,仅 KV 就超过 8 GB。 这直接限制了长上下文能力与并发数。

2 原理:多查询、单键值

2.1 设计思想

  • 只保留 h 份 Query:保持头部多样性;

  • 共享 1 份 Key / Value:删除冗余拷贝。

    这样 KV cache 从 h 倍 缩到 1 倍,注意力得分公式变为

    \text{softmax}\!\Bigl(\frac{Q_i\,K^\top}{\sqrt{d_h}}\Bigr)V,\quad i\!=\!1\ldots h

    计算 FLOPs 与 dense attention 完全一致。

2.2 数学推导

设隐藏维 d= h·d_h,序列长 n

实现

Key / Value 形状

显存复杂度

多头 (MHA)

[h, n, d_h]

O(hnd_h)

多查询 (MQA)

[1, n, d_h]

O(nd_h)

节省比例约 1/h。当 h=32 时,显存下降 31 ×。

3 代码实践:PyTorch & Transformers

from transformers import AutoModelForCausalLM, AutoConfig
config = AutoConfig.from_pretrained("tiiuae/falcon-7b")
config.multi_query = True                 # ① 打开 MQA
model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-7b",config=config,torch_dtype="auto",device_map="auto")

Hugging Face ≥ v4.35 在 Falcon, Llama-3, ChatGLM2 等权重中已内置 MQA;对于自定义模型,可在 nn.MultiheadAttention 前手动复制查询、共享 KV 并改写前向传播。源码参考 modeling_RW.py。

下面给出一个基于 GPT-style Decoder-Only 架构的 Multi-Query Attention 伪代码示例。该实现思路如下:

伪代码(gpt风格)

def multi_query_attention(X, Wq, Wkv, mask):"""X: [B, T, D] 输入隐藏状态Wq: [D, H * d_h] 查询投影Wkv: [D, 2 * d_h] 键值投影(Key 和 Value 共享)mask: [T, T] 因果掩码,下三角为 True,上三角为 False返回: [B, T, D] 注意力输出"""B, T, D = X.shapeH = num_headsd_h = D // H# 1. 计算多头查询 Q: [B, T, H, d_h]#    先线性映射 -> [B, T, H*d_h] -> reshapeQ = X @ Wq                      # [B, T, H*d_h]Q = Q.reshape(B, T, H, d_h)     # [B, T, H, d_h]# 2. 计算共享的 K, V: [B, T, 1, d_h] 各一份KV = X @ Wkv                    # [B, T, 2*d_h]K_shared, V_shared = split(KV, 2, axis=-1)  # 各 [B, T, d_h]# 为方便多头计算,插入头维度大小=1K = K_shared.reshape(B, T, 1, d_h)  # [B, T, 1, d_h]V = V_shared.reshape(B, T, 1, d_h)  # [B, T, 1, d_h]# 3. 计算注意力分数并加掩码#    scores = Q @ K^T / sqrt(d_h)  => [B, H, T, T]#    mask 后 softmax -> weightssqrt_d = math.sqrt(d_h)# 先转置 K 以便矩阵乘K_t = K.permute(0, 2, 3, 1)      # [B, 1, d_h, T]# Q: [B, T, H, d_h] -> permute -> [B, H, T, d_h]Q_t = Q.permute(0, 2, 1, 3)      # [B, H, T, d_h]scores = (Q_t @ K_t) / sqrt_d    # [B, H, T, T]# 应用因果掩码(把上三角置为 -inf)scores = scores.masked_fill(~mask[None, None, :, :], -inf)weights = softmax(scores, axis=-1)  # [B, H, T, T]# 4. 加权 V 得到每头输出#    weights [B, H, T, T] 乘以 V [B, T, 1, d_h]#    先 reshape V 以对齐: [B, 1, T, d_h]V_t = V.permute(0, 2, 1, 3)      # [B, 1, T, d_h]# 输出 head_out: [B, H, T, d_h]head_out = weights @ V_t         # [B, H, T, d_h]# 5. 拼回原始维度#    head_out -> [B, T, H, d_h] -> reshape [B, T, D]head_out = head_out.permute(0, 2, 1, 3)  # [B, T, H, d_h]out = head_out.reshape(B, T, D)         # [B, T, D]return out

说明

  • Wq 将每个位置的向量映射成 H 份 Query,而 Wkv 只生成一份 Key/Value

  • mask 是一个下三角布尔矩阵,用于保证自回归生成仅访问前序位置。

  • 各头共享同一份 K、V,但各自有独立的 Q,可并行计算。

整合到 GPT Block

在 GPT-Decoder Block 中,只需将原本的 MHA 换成上面 multi_query_attention,其余残差、LayerNorm、FFN 等保持不变:

def gpt_block(X, params):# 1. LayerNorm 前归一化X_norm = LayerNorm(X)# 2. Multi-Query Attentionattn_out = multi_query_attention(X_norm,params.Wq,params.Wkv,causal_mask(X.shape[1]))# 3. 残差连接X = X + attn_out# 4. LayerNorm + 前馈 FFNY = LayerNorm(X)ffn_out = FeedForward(Y, params.ffn)X = X + ffn_outreturn X

如此,即可在 GPT-类模型中原地启用 Multi-Query Attention,实现 KV 去复用、显存节省和推理提速。


4 典型模型与实测收益

模型

参数

采用 MQA

长序推理显存↓

吞吐↑

来源

Falcon-40B

40 B

默认

-60 %

+35 %

ChatGLM2-6B

6 B

默认

-50 %

+42 %

Llama-3-Instruct-8B

8 B

默认

-58 %

+33 %

5 与 FlashAttention 的协同

FlashAttention 负责 块化读写 + SRAM 缓存,而 MQA 负责 KV 去冗余;两者叠加可将显存再降 1/3,并在 16 K-32 K context 下保持 2 × 以上 GPU 吞吐。

6 优缺点分析

6.1 优势

  • 显存占用大幅降低,推理/训练可上更长序列或更大 batch。

  • 内存带宽需求下降,带来 30-40 %的实际加速。

  • 易于集成:只改 Attention Kernel,不动模型参数形状。

6.2 潜在不足

  • 头间 Key/Value 共享可能略减精准度,在极端细粒度任务上需调参弥补。

  • 目前主流实现只支持 Decoder-Only,Encoder-Decoder 尚需额外 kernel。

7 结语

在“长文本 + 轻量化”浪潮下,Multi-Query Attention 已成为大模型的必选项。只需一行配置即可吃到显存减半、速度翻倍的“硬件红利”,你还不赶快试试吗?

👍 点个赞 | ⭐ 收藏 | 💬 评论区聊聊 | 🔄 转发给同事——你的支持是我持续更新的最大动力!

参考文献

  1. Shazeer N. “Multi-Query Attention with Key/Value Memory Reduction.” Google Research (2022). 

  2. Google AI Blog, “Efficient Transformer Inference via MQA.” 2022. 

  3. Dao T. et al., “FlashAttention.” NeurIPS 2023. 

  4. Falcon-40B 技术博客,TII 2023. 

  5. Hugging Face Blog, “Llama-3 with Multi-Query Attention.” 2024. 

  6. Fireworks AI, “Multi-Query Attention Is All You Need.” 2023. 

  7. 清华 KEG,“ChatGLM2-6B 模型卡.” 2023. 

  8. TII Discussion #46,“Where is multiquery attention code?” 2023. 

  9. Patwary M. et al., “Efficient Inference with MQA in Megatron-LM.” NVIDIA Tech Report 2023. 

http://www.xdnf.cn/news/573481.html

相关文章:

  • IP核警告,Bus Interface ‘AD_clk‘: ASSOCIATED_BUSIF bus parameter is missing.
  • python生成requirements.txt文件
  • ABC 353
  • ROS2 CV_bridge与opencv版本冲突
  • 学习 Pinia 状态管理【Plan - May - Week 2】
  • 创建一个element plus项目
  • [C++入门]类和对象下
  • 东莞一锂离子电池公司IPO终止,客户与供应商重叠,社保缴纳情况引疑
  • GitLab 配置 webhook
  • 越小越优先和越大越优先
  • oracle使用SPM控制执行计划
  • 使用Redis的Bitmap实现了签到功能
  • iPaaS集成平台技术选型关注哪些指标?
  • HJ20 密码验证合格程序【牛客网】
  • 测试W5500的第4步_使用ioLibrary库创建UDP客户端和服务器端
  • 数据结构核心知识总结:从基础到应用
  • 6-码蹄集600题基础python篇
  • Mysql数据库相关命令及操作
  • 链表-两两交换链表中的节点
  • Mysql差异备份与恢复
  • Python图像处理全攻略:从基础到前沿技术深度剖析
  • 极大似然估计与机器学习
  • python查询elasticsearch 获取指定字段的值的list
  • 操作系统期末复习(一)
  • 淘宝扭蛋机小程序开发:开启电商娱乐新玩法
  • 工程项目交付质量低?如何构建标准化管理体系?
  • C++网络编程入门学习(四)-- GDB 调试 学习 笔记
  • 第9.1讲、Tiny Encoder Transformer:极简文本分类与注意力可视化实战
  • 计算机操作系统(十)调度的概念与层次,进程调度的时机与进程的调度方式
  • LVLM-AFAH论文精读