当前位置: 首页 > news >正文

Scaled Dot-Product Attention 中的缩放操作

最近看代码又看到了一种缩放操作,记得之前了解过Transformer 中的缩放,但是细节记不清了,这里记录一下方便以后查阅。

Scaled Dot-Product Attention 中的缩放操作

1. Scaled Dot-Product Attention 机制简介

在 Transformer 的多头注意力中,Scaled Dot-Product Attention 的计算公式为:
Attention ( Q , K , V ) = softmax ( Q K ⊤ d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QK)V
其中:

  • Q , K ∈ R N × d k Q, K \in \mathbb{R}^{N \times d_k} Q,KRN×dk:查询(query)和键(key)张量, N N N 是序列长度, d k d_k dk 是每个头的维度(在你的代码中对应 head_dim)。
  • V ∈ R N × d v V \in \mathbb{R}^{N \times d_v} VRN×dv:值(value)张量, d v d_v dv 通常等于 d k d_k dk
  • Q K ⊤ ∈ R N × N QK^\top \in \mathbb{R}^{N \times N} QKRN×N:注意力分数矩阵,表示每个查询与每个键的点积。
  • d k \sqrt{d_k} dk :缩放因子,用于归一化点积。
  • softmax \text{softmax} softmax:按行归一化注意力分数,使其和为 1。

缩放的目的
缩放的核心目标是控制注意力分数 Q K ⊤ QK^\top QK 的方差,从而:

  • 提高数值稳定性:避免 softmax 输入过大导致的溢出或退化问题。
  • 提高训练稳定性:确保梯度分布合理,促进优化过程的收敛。

2. 为什么要控制注意力分数的方差?

注意力分数 Q K ⊤ QK^\top QK 的每个元素 s i j = q i ⊤ k j s_{ij} = \mathbf{q}_i^\top \mathbf{k}_j sij=qikj 是查询向量 q i ∈ R d k \mathbf{q}_i \in \mathbb{R}^{d_k} qiRdk 和键向量 k j ∈ R d k \mathbf{k}_j \in \mathbb{R}^{d_k} kjRdk 的点积:
s i j = ∑ k = 1 d k q i , k k j , k s_{ij} = \sum_{k=1}^{d_k} q_{i,k} k_{j,k} sij=k=1dkqi,kkj,k
在 Transformer 中,查询和键向量通常来自线性投影,其元素可以近似看作独立同分布的随机变量,均值为 0,方差为 1(例如,经过初始化的权重矩阵或归一化处理)。我们来分析点积的统计特性。

2.1 未经缩放的点积方差

假设 q i , k , k j , k ∼ N ( 0 , 1 ) q_{i,k}, k_{j,k} \sim \mathcal{N}(0, 1) qi,k,kj,kN(0,1)(均值 0,方差 1),则:

  • 每个乘积项 q i , k k j , k q_{i,k} k_{j,k} qi,kkj,k 的均值为:
    E [ q i , k k j , k ] = E [ q i , k ] ⋅ E [ k j , k ] = 0 ⋅ 0 = 0 \mathbb{E}[q_{i,k} k_{j,k}] = \mathbb{E}[q_{i,k}] \cdot \mathbb{E}[k_{j,k}] = 0 \cdot 0 = 0 E[qi,kkj,k]=E[qi,k]E[kj,k]=00=0
  • 方差为:
    Var ( q i , k k j , k ) = E [ ( q i , k k j , k ) 2 ] = E [ q i , k 2 ] ⋅ E [ k j , k 2 ] = 1 ⋅ 1 = 1 \text{Var}(q_{i,k} k_{j,k}) = \mathbb{E}[(q_{i,k} k_{j,k})^2] = \mathbb{E}[q_{i,k}^2] \cdot \mathbb{E}[k_{j,k}^2] = 1 \cdot 1 = 1 Var(qi,kkj,k)=E[(qi,kkj,k)2]=E[qi,k2]E[kj,k2]=11=1
  • 点积 s i j s_{ij} sij d k d_k dk 个独立项的和:
    s i j = ∑ k = 1 d k q i , k k j , k s_{ij} = \sum_{k=1}^{d_k} q_{i,k} k_{j,k} sij=k=1dkqi,kkj,k
    因此:
    • 均值: E [ s i j ] = ∑ k = 1 d k E [ q i , k k j , k ] = 0 \mathbb{E}[s_{ij}] = \sum_{k=1}^{d_k} \mathbb{E}[q_{i,k} k_{j,k}] = 0 E[sij]=k=1dkE[qi,kkj,k]=0
    • 方差: Var ( s i j ) = ∑ k = 1 d k Var ( q i , k k j , k ) = d k \text{Var}(s_{ij}) = \sum_{k=1}^{d_k} \text{Var}(q_{i,k} k_{j,k}) = d_k Var(sij)=k=1dkVar(qi,kkj,k)=dk

问题

  • d k d_k dk 较大(例如,64、128 或更高)时,点积 s i j s_{ij} sij 的方差为 d k d_k dk,意味着其值可能非常大(标准差为 d k \sqrt{d_k} dk )。
  • 例如,若 d k = 64 d_k = 64 dk=64,则 Var ( s i j ) = 64 \text{Var}(s_{ij}) = 64 Var(sij)=64,标准差为 64 = 8 \sqrt{64} = 8 64 =8,点积值可能在 ([-24, 24]) 或更广范围内波动(假设正态分布,约 99.7% 的值在 ± 3 σ \pm 3\sigma ±3σ 内)。

这种大方差会导致以下问题:

  1. Softmax 数值不稳定
    • Softmax 函数 softmax ( s i j ) = exp ⁡ ( s i j ) ∑ j exp ⁡ ( s i j ) \text{softmax}(s_{ij}) = \frac{\exp(s_{ij})}{\sum_j \exp(s_{ij})} softmax(sij)=jexp(sij)exp(sij) 对输入值非常敏感。
    • 如果 s i j s_{ij} sij 的绝对值较大(例如 24),则 exp ⁡ ( 24 ) ≈ 2.6 × 10 10 \exp(24) \approx 2.6 \times 10^{10} exp(24)2.6×1010,可能导致:
      • 溢出:在 FP16 或 BF16 精度下,指数值可能超出数值范围。
      • 退化:softmax 输出集中在最大值上(接近 one-hot),导致其他位置的权重接近 0,丢失信息。
  2. 训练不稳定
    • 大幅度的 s i j s_{ij} sij 会导致 softmax 后的梯度要么过大(爆炸),要么过小(消失),影响优化器的收敛。
    • 梯度不稳定会使模型难以学习复杂的模式,尤其在深层网络中。

2.2 缩放后的方差

缩放操作通过除以 d k \sqrt{d_k} dk 来归一化点积:
s i j ′ = s i j d k = q i ⊤ k j d k s_{ij}' = \frac{s_{ij}}{\sqrt{d_k}} = \frac{\mathbf{q}_i^\top \mathbf{k}_j}{\sqrt{d_k}} sij=dk sij=dk qikj

  • 均值:
    E [ s i j ′ ] = E [ s i j ] d k = 0 \mathbb{E}[s_{ij}'] = \frac{\mathbb{E}[s_{ij}]}{\sqrt{d_k}} = 0 E[sij]=dk E[sij]=0
  • 方差:
    Var ( s i j ′ ) = Var ( s i j ) d k = d k d k = 1 \text{Var}(s_{ij}') = \frac{\text{Var}(s_{ij})}{d_k} = \frac{d_k}{d_k} = 1 Var(sij)=dkVar(sij)=dkdk=1
  • 标准差:
    Var ( s i j ′ ) = 1 = 1 \sqrt{\text{Var}(s_{ij}')} = \sqrt{1} = 1 Var(sij) =1 =1

效果

  • 缩放后,注意力分数 s i j ′ s_{ij}' sij 的方差固定为 1,标准差为 1,无论 d k d_k dk 的大小。
  • 例如,若 d k = 64 d_k = 64 dk=64,未经缩放的 s i j s_{ij} sij 标准差为 8,缩放后标准差降为 1,值范围缩小到 ([-3, 3])(正态分布的 99.7% 范围)。

这使得:

  • Softmax 输入的范围更合理,指数值 exp ⁡ ( s i j ′ ) \exp(s_{ij}') exp(sij) 不会轻易溢出。
  • Softmax 输出更均匀,保留了不同键的相对权重,避免退化为 one-hot。

2.3 为什么选择 d k \sqrt{d_k} dk

选择 d k \sqrt{d_k} dk 作为缩放因子是因为它精确地抵消了点积的方差增长:

  • 点积的方差与 d k d_k dk 成正比,缩放因子 d k \sqrt{d_k} dk 将方差归一化为 1。
  • 其他缩放因子(如 d k d_k dk 或常数)可能导致方差过小(抑制信号)或过大(仍不稳定)。

例如:

  • 若缩放因子为 d k d_k dk,则 Var ( s i j ′ ) = d k d k 2 = 1 d k \text{Var}(s_{ij}') = \frac{d_k}{d_k^2} = \frac{1}{d_k} Var(sij)=dk2dk=dk1,方差过小,信号被压缩。
  • 若不缩放(因子为 1),则 Var ( s i j ′ ) = d k \text{Var}(s_{ij}') = d_k Var(sij)=dk,方差过大,导致数值问题。

d k \sqrt{d_k} dk 是一个理论上和实践上平衡的选择,广泛应用于 Transformer 模型。

3. 总结

“缩放的目的是为了控制注意力分数的方差,从而提高数值稳定性和训练稳定性”在 Scaled Dot-Product Attention 机制中的具体含义是:

  • 控制方差:点积 Q K ⊤ QK^\top QK 的方差与头维度 d k d_k dk 成正比( Var = d k \text{Var} = d_k Var=dk)。通过除以 d k \sqrt{d_k} dk ,方差归一化为 1,注意力分数的范围缩小(例如 ([-3, 3])),适配 softmax 计算。
  • 提高数值稳定性:缩放后的分数防止了 softmax 的溢出或退化(one-hot),尤其在 FP16/BF16 或高维度(如 d k = 128 d_k = 128 dk=128)场景下,减少了 NaN 或精度丢失的风险。
  • 提高训练稳定性:稳定的注意力分数分布导致更均匀的 softmax 输出,梯度分布合理,避免了爆炸或消失,促进深层网络的收敛。

通过缩放,Scaled Dot-Product Attention 能够在多种场景下(长序列、深层网络、低精度训练)保持高效和稳定。

http://www.xdnf.cn/news/593893.html

相关文章:

  • Spring Cloud生态与技术选型指南:如何构建高可用的微服务系统?
  • C语言:gcc 或 g++ 数组边界检查方法
  • 山东大学软件学院创新项目实训开发日志——第十二周
  • 2021~2025:特斯拉人形机器人Optimus发展进程详解
  • UV-python环境管理工具 入门教程
  • 时源芯微|电源、地线的处理
  • 技术篇-2.4.Python应用场景及开发工具安装
  • JMeter JDBC请求Query Type实测(金仓数据库版)
  • springboot3+vue3融合项目实战-大事件文章管理系统-本地存储及阿里云oss程序集成
  • 一文读懂Agent智能体,从概念到应用—Agent百科
  • GTM4.1-SPE
  • spring+tomcat 用户每次发请求,tomcat 站在线程的角度是如何处理用户请求的,spinrg的bean 是共享的吗
  • 练习写作对口语输出有显著的促进作用
  • Zephyr OS 中的互斥信号量
  • 高等数学-微分
  • SDWebImage源码学习
  • 容器资源绑定和查看
  • 中医方剂 - 理中汤
  • 车载网关策略 --- 车载网关重置前的请求转发机制
  • HarmonyOS学习——UIAbility组件(上)
  • 有监督学习——决策树
  • 咬合配准算法文献推荐
  • 机器学习圣经PRML作者Bishop20年后新作中文版出版!
  • Apollo10.0学习——planning模块(10)之依赖注入器injector_
  • 交换机工作原理解析与网络安全实践
  • 4个关键功能,让健康管理系统真正发挥作用
  • 基于Java的体育场馆预约系统的设计与实现【附源码】
  • Web3.0:下一代互联网的变革与机遇
  • [原创](现代Delphi 12指南):[macOS 64bit App开发]: 如何获取目标App的程序图标?
  • 论文解读 | 《桑黄提取物对小鼠宫颈癌皮下移植瘤的抑制及机制研究》