模拟注意力:少量参数放大 Attention 表征能力
论文标题
SAS: Simulated Attention Score
论文地址
https://arxiv.org/pdf/2507.07694
代码
见论文附录
作者背景
摩根士丹利,斯坦福大学,微软研究院,新加坡国立大学,得克萨斯大学奥斯汀分校,香港大学
动机
多头注意力是 Transformer 的核心组件,它通过引入多组 QKV 投影来捕获不同的特征子空间,从而在机器翻译、问答等任务中取得巨大成功。研究表明,注意力头的数量对 Transformer 性能至关重要:在保证每个头的隐藏维度充分大的前提下,注意力头数越多可以使模型效果越好。但问题在于,直接增加头数或维度往往伴随着模型参数量和计算开销的剧增,这在训练和部署中代价高昂
目前也有一些注意力架构旨在提高计算效率,例如共享部分 K 和 V 的 MQA、GQA;使用矩阵分解的 MLA、MFA、TPA 等。但这些方法主要关注降低内存/计算成本,而非提升注意力的表达能力
于是作者希望在不显著增加参数的前提下,设计一种新的注意力架构,实现近似于使用了更多注意力头和更高每头维度的性能提升
本文方法
本文提出 SAS(Simulated Attention Score,模拟注意力分数),核心思想是在注意力计算中引入额外的映射层,将低维的头表示投射到更高维空间,以此“虚拟地”增大注意力头数和每头的隐藏维度
一、扩展注意力头
对于查询Q,其特征维度为 [B, T, H, D],分别表示 batch_size,序列长度,头数和隐藏维度。为了扩充 H,需要把其他维度拉平,得到张量 Q_0,维度为 [B * T * D, H] ;然后使用一个 H * H’ 的线性变换得到 Q_1,维度为 [B * T * D, H’],其中 H’ > H;Q_1 过一个 ReLU 引入非线性;最后再过一个 H’ * H’ 的线性层,并加上 Q_1 的残差连接
于是我们获得了更多的注意力头,其中残差连接的引入可以稳定训练;值得注意的是,原始头数 H 和扩展后的头数 H’ 都远小于每头的特征维度 D,所以这个两层 MLP 的参数开销相对整模型来说可以忽略不计
除了使用 MLP 来扩展维度,作者还尝试了卷积方案。具体地,将查询 Q 的维度整理成 [B * T, H, D],类似于多通道特征图,然后使用卷积变换将 H 扩展成 H’,同样地,H’ > H,最后再过第二层卷积以及残差连接
类似地,在 K、V 中都应用上述扩展流程
二、扩展注意力维度
直觉上,每个注意力头内部特征维度 D 越大,其能够捕获的子空间信息越丰富。因此作者进一步在 Q 和 K 上也引入了类似的维度扩展映射。这里之所以不对 V 进行扩展,是因为 V
直接决定了注意力模块的输出张量隐藏维度,扩大 V 的每头维度到 D 会导致后续前馈层的参数量大幅增加,违背了不显著增加计算量的初衷
三、注意力聚合
在标准多头注意力中,会将所有头的输出向量拼接,再通过一个输出投影矩阵 O 映射回模型的隐藏维度。然而,由于 SAS 对注意力头数进行了扩增,若仍按传统方式拼接势必导致输出维度变大,进而导致 O 的参数量大大增加(H * hidden 变为 H’ * hidden)。为此,作者提出了参数高效注意力聚合机制,旨在不增加输出层参数规模的情况下完成对多头输出的整合
实现过程非常简单:假设注意力头数扩展了 r 倍,即 r * H = H’,那么便把所有头划分成 r 组,每组都按照原本的计算流程与 O 相乘,得到 r 组输出结果,最后取平均作为注意力模块的最终输出传向前馈层
实验结果
作者在多种基准任务和数据集上对SAS进行了验证,包括语言模型预训练及下游任务评估,全面展示了SAS在准确率和效率方面的优势
一、预训练效果
下图对比了SAS与标准MHA、MQA、GQA、MLA、TPA等方法在ArXiv和Books3数据集上的表现。结果表明,无论是短序列训练(长度512)还是长序列训练(长度1024),SAS均取得了最低的验证困惑度
除了取得更好的性能,SAS还加速了模型的收敛。作者报告,在 Books3 数据集、序列长度512的训练中,MHA模型在5万步时达到29.86的验证困惑度,而SAS模型在3万步时就达到了相近的30.49,即 SAS 可以节约 40% 左右的计算资源
此外,作者还在更大的训练长度、更大的模型尺寸上做了验证,结果表明相比于其他注意力机制 SAS 具备稳定的优势
二、下游任务效果
作者评测了在多个下游任务基准(ARC、HellaSwag、PIQA、ScIQ、SocialIQA、WinoGrande)上 SAS 与其他注意力模型的效果,可见在多种参数量、训练数据量的实验设置下,SAS 大部分情况下都表现出了最优性能