当前位置: 首页 > news >正文

多头注意力(Multi‑Head Attention)

1. 多头注意力(Multi‑Head Attention)原理

设输入序列表示为矩阵 X ∈ R B × L × d model X\in\mathbb{R}^{B\times L\times d_{\text{model}}} XRB×L×dmodel,其中

  • B B B:批大小(batch size),
  • L L L:序列长度(sequence length),
  • d model d_{\text{model}} dmodel:模型隐层维度(model dimension)。

多头注意力基于对缩放点乘注意力的并行化扩展,引入了 h h h 个注意力头(heads),每个头在不同子空间中学习不同的表示。

1.1 线性映射与切分

我们首先为每个头定义三组可学习权重:
W i Q ∈ R d model × d k , W i K ∈ R d model × d k , W i V ∈ R d model × d v , i = 1 , … , h W_i^Q \in \mathbb{R}^{d_{\text{model}}\times d_k},\quad W_i^K \in \mathbb{R}^{d_{\text{model}}\times d_k},\quad W_i^V \in \mathbb{R}^{d_{\text{model}}\times d_v}, \quad i=1,\dots,h WiQRdmodel×dk,WiKRdmodel×dk,WiVRdmodel×dv,i=1,,h
其中

  • h h h:头数(number of heads),
  • d k d_k dk:每个头中 Query/Key 的维度(key/query dimension),
  • d v d_v dv:每个头中 Value 的维度(value dimension),
  • 通常 d model = h × d k d_{\text{model}}=h\times d_k dmodel=h×dk 且取 d v = d k d_v = d_k dv=dk

对输入 X X X 进行投影,得到第 i i i 个头的查询、键、值:
Q i = X W i Q , K i = X W i K , V i = X W i V Q_i = X\,W_i^Q,\quad K_i = X\,W_i^K,\quad V_i = X\,W_i^V Qi=XWiQ,Ki=XWiK,Vi=XWiV
其中

  • Q i ∈ R B × L × d k Q_i \in \mathbb{R}^{B\times L\times d_k} QiRB×L×dk
  • K i ∈ R B × L × d k K_i \in \mathbb{R}^{B\times L\times d_k} KiRB×L×dk
  • V i ∈ R B × L × d v V_i \in \mathbb{R}^{B\times L\times d_v} ViRB×L×dv

1.2 缩放点乘注意力

对第 i i i 个头,分别对所有位置做点积注意力:

  1. 打分矩阵
    S c o r e i = Q i K i ⊤ ∈ R B × L × L \mathrm{Score}_i = Q_i\,K_i^\top \quad\in\mathbb{R}^{B\times L\times L} Scorei=QiKiRB×L×L
  2. 缩放
    S c o r e ~ i = S c o r e i d k \widetilde{\mathrm{Score}}_i = \frac{\mathrm{Score}_i}{\sqrt{d_k}} Score i=dk Scorei
  3. Softmax 归一化
    A i = s o f t m a x ( S c o r e ~ i ) ∈ R B × L × L A_i = \mathrm{softmax}\bigl(\widetilde{\mathrm{Score}}_i\bigr) \quad\in\mathbb{R}^{B\times L\times L} Ai=softmax(Score i)RB×L×L
  4. 加权求和
    h e a d i = A i V i ∈ R B × L × d v \mathrm{head}_i = A_i\,V_i \quad\in\mathbb{R}^{B\times L\times d_v} headi=AiViRB×L×dv

1.3 拼接与线性变换

将所有头的输出在最后一维拼接,再做一次线性投影:
C o n c a t = [ h e a d 1 , … , h e a d h ] ∈ R B × L × ( h d v ) \mathrm{Concat} = \bigl[\mathrm{head}_1,\dots,\mathrm{head}_h\bigr] \quad\in\mathbb{R}^{B\times L\times (h\,d_v)} Concat=[head1,,headh]RB×L×(hdv)
定义输出权重矩阵
W O ∈ R ( h d v ) × d model W^O\in\mathbb{R}^{(h\,d_v)\times d_{\text{model}}} WOR(hdv)×dmodel
最终输出:
M u l t i H e a d ( X ) = C o n c a t W O ∈ R B × L × d model \mathrm{MultiHead}(X) = \mathrm{Concat}\;W^O \quad\in\mathbb{R}^{B\times L\times d_{\text{model}}} MultiHead(X)=ConcatWORB×L×dmodel


2. PyTorch 代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import mathclass MultiHeadAttention(nn.Module):def __init__(self, d_model: int, h: int):"""d_model: 模型维度 d_modelh: 注意力头数 h"""super().__init__()assert d_model % h == 0, "d_model 必须能被 h 整除"self.d_model = d_model      # d_modelself.h = h                  # hself.d_k = d_model // h     # d_k = d_model / hself.d_v = self.d_k         # d_v 通常等于 d_k# 投影矩阵 W_i^Q, W_i^K, W_i^V,实际上合并为一个大矩阵后在 forward 再切分self.w_q = nn.Linear(d_model, d_model)  # 同时生成 h 个 Q 投影self.w_k = nn.Linear(d_model, d_model)  # 同时生成 h 个 K 投影self.w_v = nn.Linear(d_model, d_model)  # 同时生成 h 个 V 投影# 输出线性变换 W^Oself.w_o = nn.Linear(d_model, d_model)def forward(self, X: torch.Tensor, mask: torch.Tensor = None):"""X: 输入张量,形状 (B, L, d_model)mask: 可选掩码,形状 (B, 1, L, L) 或 (B, L, L)"""B, L, _ = X.size()# 1. 线性映射到 Q, K, V,然后切分 h 头#    先得到 (B, L, h*d_k),再 view/transpose 为 (B, h, L, d_k)Q = self.w_q(X).view(B, L, self.h, self.d_k).transpose(1, 2)K = self.w_k(X).view(B, L, self.h, self.d_k).transpose(1, 2)V = self.w_v(X).view(B, L, self.h, self.d_k).transpose(1, 2)# 此时 Q, K, V 形状均为 (B, h, L, d_k)# 2. 计算点积注意力#    scores = Q @ K^T  -> (B, h, L, L)scores = torch.matmul(Q, K.transpose(-2, -1))  #    缩放:除以 sqrt(d_k)scores = scores / math.sqrt(self.d_k)#    可选掩码:将被屏蔽位置设为 -inf if mask is not None:scores = scores.masked_fill(mask == 0, float('-inf'))#    Softmax 归一化 -> (B, h, L, L)A = F.softmax(scores, dim=-1)#    加权求和 -> head_i 形状 (B, h, L, d_k)heads = torch.matmul(A, V)# 3. 拼接 h 个头:transpose 回 (B, L, h, d_k) 再 reshapeconcat = heads.transpose(1, 2).contiguous().view(B, L, self.h * self.d_k)#    concat 形状 (B, L, h*d_k) == (B, L, d_model)# 4. 最后一次线性变换 W^Ooutput = self.w_o(concat)  # -> (B, L, d_model)return output, A  # 返回输出及注意力权重 A
http://www.xdnf.cn/news/90325.html

相关文章:

  • 鸣潮赞妮技能机制解析 鸣潮赞妮配队推荐
  • 路由交换网络专题 | 第六章 | OSPF | BGP | BGP属性 | 防环机制
  • RS232借助 Profinet网关与调制解调器碰撞出的火花
  • 探秘云原生架构:概念、技术、设计与反模式深度解读
  • strlen参数不匹配编译报错处理
  • 前端做模糊查询(含AI版)
  • 操作系统——堆与栈详解:内存结构全面科普
  • 电商平台比价 API 接口,避免人工比价的低效与误差
  • Mellanox网卡qos设置
  • window如何关闭指定端口
  • 嵌入式人工智能应用-第三章 opencv操作8 图像特征之LBP特征 下
  • 【C++游戏引擎开发】第20篇:基于物理渲染(PBR)——辐射度量学
  • 如何一键提取多个 PPT 幻灯片中的备注到 TXT 记事本文件中
  • 爱普生FC-12M晶振在车载系统中广泛应用
  • Spring事件机制,如何使用Spring事件监听器
  • Vue 实例 VM 访问属性
  • 【MySQL】索引失效问题详解
  • STM32单片机入门学习——第46节: [14-1] WDG看门狗
  • 怎样用 esProc 提速主子表关联时的 EXISTS
  • 利用参考基因组fa和注释文件gff提取蛋白编码序列
  • 定义python中的函数和类
  • SVT-AV1编码器中的模块
  • 如何收集用户白屏/长时间无响应/接口超时问题
  • linux命令集
  • 来啦,烫,查询达梦表占用空间
  • SVT-AV1编码器初始化函数
  • Linux 系统监控基石:top 命令详解与实战指南
  • 华为仓颉编程语言基础概述
  • JavaFX深度实践:从零构建高级打地鼠游戏(含多物品与反馈机制)
  • Windows7升级Windows10,无法在此驱动器上安装Windows