GQA(Grouped Query Attention):分组注意力机制的原理与实践《二》
🌟 Grouped Query Attention (GQA) 核心公式
GQA 的目标是减少 Key/Value
的数量,通过引入共享的 Key/Value 组,实现更高效的计算。
设:
- x ∈ R T × d x \in \mathbb{R}^{T \times d} x∈RT×d 是输入序列;
- h h h 为总头数, g g g 为 K/V 分组数(通常 g < h g < h g<h);
- 每组包含 h g = h g h_g = \frac{h}{g} hg=gh 个 Query;
- 每个头的维度为 d h = d h d_h = \frac{d}{h} dh=hd。
✅ Query/Key/Value 计算
Q = x W Q , K = x W K , V = x W V Q = xW^Q,\quad K = xW^K,\quad V = xW^V Q=xWQ,K=xWK,V=xWV
其中:
- Q ∈ R T × h × d h Q \in \mathbb{R}^{T \times h \times d_h} Q∈RT×h×dh
- K , V ∈ R T × g × d h K, V \in \mathbb{R}^{T \times g \times d_h} K,V∈RT×g×dh
说明:Q
拆分成 h h h 个头,K
和 V
只拆分成 g g g 个组,每 h g h_g hg 个 Q
共享一组 K/V
。
✅ 注意力权重计算(每个头 i i i)
令 j = ⌊ i h g ⌋ j = \left\lfloor \frac{i}{h_g} \right\rfloor j=⌊hgi⌋,表示第 i i i 个头对应第 j j j 个 K/V
组:
Attention i = softmax ( Q i K j ⊤ d h ) V j \text{Attention}_i = \text{softmax}\left( \frac{Q_i K_j^\top}{\sqrt{d_h}} \right)V_j Attentioni=softmax(dhQiKj⊤)Vj
✅ 最终输出拼接
Output = Concat ( Attention 1 , … , Attention h ) W O \text{Output} = \text{Concat}(\text{Attention}_1, \ldots, \text{Attention}_h)W^O Output=Concat(Attention1,…,Attentionh)WO
其中 W O W^O WO 为输出变换矩阵。
📌 总结
GQA 在保持多头 Query 精度的同时,大幅减少了 Key/Value 的计算和存储开销,适用于大规模模型(如 LLaMA 2/3、Qwen2 等)。