Scaled Dot-Product Attention 中的缩放操作
最近看代码又看到了一种缩放操作,记得之前了解过Transformer 中的缩放,但是细节记不清了,这里记录一下方便以后查阅。
Scaled Dot-Product Attention 中的缩放操作
- 1. Scaled Dot-Product Attention 机制简介
- 2. 为什么要控制注意力分数的方差?
- 2.1 未经缩放的点积方差
- 2.2 缩放后的方差
- 2.3 为什么选择 d k \sqrt{d_k} dk?
- 3. 总结
1. Scaled Dot-Product Attention 机制简介
在 Transformer 的多头注意力中,Scaled Dot-Product Attention 的计算公式为:
Attention ( Q , K , V ) = softmax ( Q K ⊤ d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dkQK⊤)V
其中:
- Q , K ∈ R N × d k Q, K \in \mathbb{R}^{N \times d_k} Q,K∈RN×dk:查询(query)和键(key)张量, N N N 是序列长度, d k d_k dk 是每个头的维度(在你的代码中对应
head_dim
)。 - V ∈ R N × d v V \in \mathbb{R}^{N \times d_v} V∈RN×dv:值(value)张量, d v d_v dv 通常等于 d k d_k dk。
- Q K ⊤ ∈ R N × N QK^\top \in \mathbb{R}^{N \times N} QK⊤∈RN×N:注意力分数矩阵,表示每个查询与每个键的点积。
- d k \sqrt{d_k} dk:缩放因子,用于归一化点积。
- softmax \text{softmax} softmax:按行归一化注意力分数,使其和为 1。
缩放的目的:
缩放的核心目标是控制注意力分数 Q K ⊤ QK^\top QK⊤ 的方差,从而:
- 提高数值稳定性:避免 softmax 输入过大导致的溢出或退化问题。
- 提高训练稳定性:确保梯度分布合理,促进优化过程的收敛。
2. 为什么要控制注意力分数的方差?
注意力分数 Q K ⊤ QK^\top QK⊤ 的每个元素 s i j = q i ⊤ k j s_{ij} = \mathbf{q}_i^\top \mathbf{k}_j sij=qi⊤kj 是查询向量 q i ∈ R d k \mathbf{q}_i \in \mathbb{R}^{d_k} qi∈Rdk 和键向量 k j ∈ R d k \mathbf{k}_j \in \mathbb{R}^{d_k} kj∈Rdk 的点积:
s i j = ∑ k = 1 d k q i , k k j , k s_{ij} = \sum_{k=1}^{d_k} q_{i,k} k_{j,k} sij=k=1∑dkqi,kkj,k
在 Transformer 中,查询和键向量通常来自线性投影,其元素可以近似看作独立同分布的随机变量,均值为 0,方差为 1(例如,经过初始化的权重矩阵或归一化处理)。我们来分析点积的统计特性。
2.1 未经缩放的点积方差
假设 q i , k , k j , k ∼ N ( 0 , 1 ) q_{i,k}, k_{j,k} \sim \mathcal{N}(0, 1) qi,k,kj,k∼N(0,1)(均值 0,方差 1),则:
- 每个乘积项 q i , k k j , k q_{i,k} k_{j,k} qi,kkj,k 的均值为:
E [ q i , k k j , k ] = E [ q i , k ] ⋅ E [ k j , k ] = 0 ⋅ 0 = 0 \mathbb{E}[q_{i,k} k_{j,k}] = \mathbb{E}[q_{i,k}] \cdot \mathbb{E}[k_{j,k}] = 0 \cdot 0 = 0 E[qi,kkj,k]=E[qi,k]⋅E[kj,k]=0⋅0=0 - 方差为:
Var ( q i , k k j , k ) = E [ ( q i , k k j , k ) 2 ] = E [ q i , k 2 ] ⋅ E [ k j , k 2 ] = 1 ⋅ 1 = 1 \text{Var}(q_{i,k} k_{j,k}) = \mathbb{E}[(q_{i,k} k_{j,k})^2] = \mathbb{E}[q_{i,k}^2] \cdot \mathbb{E}[k_{j,k}^2] = 1 \cdot 1 = 1 Var(qi,kkj,k)=E[(qi,kkj,k)2]=E[qi,k2]⋅E[kj,k2]=1⋅1=1 - 点积 s i j s_{ij} sij 是 d k d_k dk 个独立项的和:
s i j = ∑ k = 1 d k q i , k k j , k s_{ij} = \sum_{k=1}^{d_k} q_{i,k} k_{j,k} sij=k=1∑dkqi,kkj,k
因此:- 均值: E [ s i j ] = ∑ k = 1 d k E [ q i , k k j , k ] = 0 \mathbb{E}[s_{ij}] = \sum_{k=1}^{d_k} \mathbb{E}[q_{i,k} k_{j,k}] = 0 E[sij]=∑k=1dkE[qi,kkj,k]=0
- 方差: Var ( s i j ) = ∑ k = 1 d k Var ( q i , k k j , k ) = d k \text{Var}(s_{ij}) = \sum_{k=1}^{d_k} \text{Var}(q_{i,k} k_{j,k}) = d_k Var(sij)=∑k=1dkVar(qi,kkj,k)=dk
问题:
- 当 d k d_k dk 较大(例如,64、128 或更高)时,点积 s i j s_{ij} sij 的方差为 d k d_k dk,意味着其值可能非常大(标准差为 d k \sqrt{d_k} dk)。
- 例如,若 d k = 64 d_k = 64 dk=64,则 Var ( s i j ) = 64 \text{Var}(s_{ij}) = 64 Var(sij)=64,标准差为 64 = 8 \sqrt{64} = 8 64=8,点积值可能在 ([-24, 24]) 或更广范围内波动(假设正态分布,约 99.7% 的值在 ± 3 σ \pm 3\sigma ±3σ 内)。
这种大方差会导致以下问题:
- Softmax 数值不稳定:
- Softmax 函数 softmax ( s i j ) = exp ( s i j ) ∑ j exp ( s i j ) \text{softmax}(s_{ij}) = \frac{\exp(s_{ij})}{\sum_j \exp(s_{ij})} softmax(sij)=∑jexp(sij)exp(sij) 对输入值非常敏感。
- 如果 s i j s_{ij} sij 的绝对值较大(例如 24),则 exp ( 24 ) ≈ 2.6 × 10 10 \exp(24) \approx 2.6 \times 10^{10} exp(24)≈2.6×1010,可能导致:
- 溢出:在 FP16 或 BF16 精度下,指数值可能超出数值范围。
- 退化:softmax 输出集中在最大值上(接近 one-hot),导致其他位置的权重接近 0,丢失信息。
- 训练不稳定:
- 大幅度的 s i j s_{ij} sij 会导致 softmax 后的梯度要么过大(爆炸),要么过小(消失),影响优化器的收敛。
- 梯度不稳定会使模型难以学习复杂的模式,尤其在深层网络中。
2.2 缩放后的方差
缩放操作通过除以 d k \sqrt{d_k} dk 来归一化点积:
s i j ′ = s i j d k = q i ⊤ k j d k s_{ij}' = \frac{s_{ij}}{\sqrt{d_k}} = \frac{\mathbf{q}_i^\top \mathbf{k}_j}{\sqrt{d_k}} sij′=dksij=dkqi⊤kj
- 均值:
E [ s i j ′ ] = E [ s i j ] d k = 0 \mathbb{E}[s_{ij}'] = \frac{\mathbb{E}[s_{ij}]}{\sqrt{d_k}} = 0 E[sij′]=dkE[sij]=0 - 方差:
Var ( s i j ′ ) = Var ( s i j ) d k = d k d k = 1 \text{Var}(s_{ij}') = \frac{\text{Var}(s_{ij})}{d_k} = \frac{d_k}{d_k} = 1 Var(sij′)=dkVar(sij)=dkdk=1 - 标准差:
Var ( s i j ′ ) = 1 = 1 \sqrt{\text{Var}(s_{ij}')} = \sqrt{1} = 1 Var(sij′)=1=1
效果:
- 缩放后,注意力分数 s i j ′ s_{ij}' sij′ 的方差固定为 1,标准差为 1,无论 d k d_k dk 的大小。
- 例如,若 d k = 64 d_k = 64 dk=64,未经缩放的 s i j s_{ij} sij 标准差为 8,缩放后标准差降为 1,值范围缩小到 ([-3, 3])(正态分布的 99.7% 范围)。
这使得:
- Softmax 输入的范围更合理,指数值 exp ( s i j ′ ) \exp(s_{ij}') exp(sij′) 不会轻易溢出。
- Softmax 输出更均匀,保留了不同键的相对权重,避免退化为 one-hot。
2.3 为什么选择 d k \sqrt{d_k} dk?
选择 d k \sqrt{d_k} dk 作为缩放因子是因为它精确地抵消了点积的方差增长:
- 点积的方差与 d k d_k dk 成正比,缩放因子 d k \sqrt{d_k} dk 将方差归一化为 1。
- 其他缩放因子(如 d k d_k dk 或常数)可能导致方差过小(抑制信号)或过大(仍不稳定)。
例如:
- 若缩放因子为 d k d_k dk,则 Var ( s i j ′ ) = d k d k 2 = 1 d k \text{Var}(s_{ij}') = \frac{d_k}{d_k^2} = \frac{1}{d_k} Var(sij′)=dk2dk=dk1,方差过小,信号被压缩。
- 若不缩放(因子为 1),则 Var ( s i j ′ ) = d k \text{Var}(s_{ij}') = d_k Var(sij′)=dk,方差过大,导致数值问题。
d k \sqrt{d_k} dk 是一个理论上和实践上平衡的选择,广泛应用于 Transformer 模型。
3. 总结
“缩放的目的是为了控制注意力分数的方差,从而提高数值稳定性和训练稳定性”在 Scaled Dot-Product Attention 机制中的具体含义是:
- 控制方差:点积 Q K ⊤ QK^\top QK⊤ 的方差与头维度 d k d_k dk 成正比( Var = d k \text{Var} = d_k Var=dk)。通过除以 d k \sqrt{d_k} dk,方差归一化为 1,注意力分数的范围缩小(例如 ([-3, 3])),适配 softmax 计算。
- 提高数值稳定性:缩放后的分数防止了 softmax 的溢出或退化(one-hot),尤其在 FP16/BF16 或高维度(如 d k = 128 d_k = 128 dk=128)场景下,减少了 NaN 或精度丢失的风险。
- 提高训练稳定性:稳定的注意力分数分布导致更均匀的 softmax 输出,梯度分布合理,避免了爆炸或消失,促进深层网络的收敛。
通过缩放,Scaled Dot-Product Attention 能够在多种场景下(长序列、深层网络、低精度训练)保持高效和稳定。