当前位置: 首页 > news >正文

模拟注意力:少量参数放大 Attention 表征能力

论文标题

SAS: Simulated Attention Score

论文地址

https://arxiv.org/pdf/2507.07694

代码

见论文附录

作者背景

摩根士丹利,斯坦福大学,微软研究院,新加坡国立大学,得克萨斯大学奥斯汀分校,香港大学

动机

多头注意力是 Transformer 的核心组件,它通过引入多组 QKV 投影来捕获不同的特征子空间,从而在机器翻译、问答等任务中取得巨大成功。研究表明,注意力头的数量对 Transformer 性能至关重要:在保证每个头的隐藏维度充分大的前提下,注意力头数越多可以使模型效果越好。但问题在于,直接增加头数或维度往往伴随着模型参数量和计算开销的剧增,这在训练和部署中代价高昂

在这里插入图片描述

目前也有一些注意力架构旨在提高计算效率,例如共享部分 K 和 V 的 MQA、GQA;使用矩阵分解的 MLA、MFA、TPA 等。但这些方法主要关注降低内存/计算成本,而非提升注意力的表达能力

于是作者希望在不显著增加参数的前提下,设计一种新的注意力架构,实现近似于使用了更多注意力头和更高每头维度的性能提升

本文方法

本文提出 SAS(Simulated Attention Score,模拟注意力分数),核心思想是在注意力计算中引入额外的映射层,将低维的头表示投射到更高维空间,以此“虚拟地”增大注意力头数和每头的隐藏维度

一、扩展注意力头

对于查询Q,其特征维度为 [B, T, H, D],分别表示 batch_size,序列长度,头数和隐藏维度。为了扩充 H,需要把其他维度拉平,得到张量 Q_0,维度为 [B * T * D, H] ;然后使用一个 H * H’ 的线性变换得到 Q_1,维度为 [B * T * D, H’],其中 H’ > H;Q_1 过一个 ReLU 引入非线性;最后再过一个 H’ * H’ 的线性层,并加上 Q_1 的残差连接

在这里插入图片描述

于是我们获得了更多的注意力头,其中残差连接的引入可以稳定训练;值得注意的是,原始头数 H 和扩展后的头数 H’ 都远小于每头的特征维度 D,所以这个两层 MLP 的参数开销相对整模型来说可以忽略不计

除了使用 MLP 来扩展维度,作者还尝试了卷积方案。具体地,将查询 Q 的维度整理成 [B * T, H, D],类似于多通道特征图,然后使用卷积变换将 H 扩展成 H’,同样地,H’ > H,最后再过第二层卷积以及残差连接

在这里插入图片描述

类似地,在 K、V 中都应用上述扩展流程

二、扩展注意力维度

直觉上,每个注意力头内部特征维度 D 越大,其能够捕获的子空间信息越丰富。因此作者进一步在 Q 和 K 上也引入了类似的维度扩展映射。这里之所以不对 V 进行扩展,是因为 V

直接决定了注意力模块的输出张量隐藏维度,扩大 V 的每头维度到 D 会导致后续前馈层的参数量大幅增加,违背了不显著增加计算量的初衷

在这里插入图片描述

三、注意力聚合

在标准多头注意力中,会将所有头的输出向量拼接,再通过一个输出投影矩阵 O 映射回模型的隐藏维度。然而,由于 SAS 对注意力头数进行了扩增,若仍按传统方式拼接势必导致输出维度变大,进而导致 O 的参数量大大增加(H * hidden 变为 H’ * hidden)。为此,作者提出了参数高效注意力聚合机制,旨在不增加输出层参数规模的情况下完成对多头输出的整合

实现过程非常简单:假设注意力头数扩展了 r 倍,即 r * H = H’,那么便把所有头划分成 r 组,每组都按照原本的计算流程与 O 相乘,得到 r 组输出结果,最后取平均作为注意力模块的最终输出传向前馈层

在这里插入图片描述

实验结果

作者在多种基准任务和数据集上对SAS进行了验证,包括语言模型预训练及下游任务评估,全面展示了SAS在准确率和效率方面的优势

一、预训练效果

下图对比了SAS与标准MHA、MQA、GQA、MLA、TPA等方法在ArXiv和Books3数据集上的表现。结果表明,无论是短序列训练(长度512)还是长序列训练(长度1024),SAS均取得了最低的验证困惑度

在这里插入图片描述

除了取得更好的性能,SAS还加速了模型的收敛。作者报告,在 Books3 数据集、序列长度512的训练中,MHA模型在5万步时达到29.86的验证困惑度,而SAS模型在3万步时就达到了相近的30.49,即 SAS 可以节约 40% 左右的计算资源

此外,作者还在更大的训练长度、更大的模型尺寸上做了验证,结果表明相比于其他注意力机制 SAS 具备稳定的优势

二、下游任务效果

作者评测了在多个下游任务基准(ARC、HellaSwag、PIQA、ScIQ、SocialIQA、WinoGrande)上 SAS 与其他注意力模型的效果,可见在多种参数量、训练数据量的实验设置下,SAS 大部分情况下都表现出了最优性能

在这里插入图片描述

http://www.xdnf.cn/news/1111825.html

相关文章:

  • C#与FX5U进行Socket通信
  • 【设计模式】桥接模式(柄体模式,接口模式)
  • OneCode 3.0架构深度剖析:工程化模块管理与自治UI系统的设计与实现
  • 企业商业秘密保卫战:经营信息类案件维权全攻略
  • 分布式系统高可用性设计 - 缓存策略与数据同步机制
  • wedo稻草人-----第32节(免费分享图纸)
  • 实验一 接苹果
  • LeetCode经典题解:3、无重复字符的最长子串
  • ADI的EV-21569-SOM核心板和主板转接卡的链接说明
  • Kubernetes持久卷实战
  • 13. G1垃圾回收器
  • os.loadavg()详解
  • Python 训练营打卡 Day 59-经典时序预测模型3
  • Java 大视界 -- Java 大数据机器学习模型在电商用户复购行为预测与客户关系维护中的应用(343)
  • IDEA中一个服务创建多个实例
  • 【C/C++】迈出编译第一步——预处理
  • [案例八] NX二次开发长圆孔的实现(支持实体)
  • TensorFlow2 study notes[2]
  • 【Linux网络】IP 协议详解:结构、地址与交付机制全面解析
  • 算法第三十一天:贪心算法part05(第八章)
  • Qt 多线程编程:单例任务队列的设计与实现
  • 【数据结构初阶】--顺序表(二)
  • 【读书笔记】《C++ Software Design》第一章《The Art of Software Design》
  • 【一起来学AI大模型】RAG系统组件:检索器(LangChain)
  • Python 实战:构建可扩展的命令行插件引擎
  • 试用了10款翻译软件后,我只推荐这一款!完全免费还超好用
  • 挖矿病毒判断与处理 - 入门
  • DBeaver连接MySQL8.0报错Public Key Retrieval is not allowed
  • Redis集群会有写操作丢失吗?为什么?
  • 1. 好的设计原则