当前位置: 首页 > news >正文

Transformer之多头注意力机制和位置编码(二)

Transformer之多头注意力机制和位置编码(二)

文章目录

  • Transformer之多头注意力机制和位置编码(二)
  • 一、 多头注意力(Multi-Head Attention)
    • 案例
  • 二、位置编码(Positional Encoding)
    • 2.1 固定正弦/余弦
    • 2.2 可学习编码


一、 多头注意力(Multi-Head Attention)

先把整句映射成 Q、K、V,再按列切分 → 多头并行计算

核心公式

headh=Attention(QWhQ,KWhK,VWhV)MultiHead(Q,K,V)=[head1;…;headH]WO\begin{aligned} \mathrm{head}_h &= \mathrm{Attention}(Q W^Q_h,\,K W^K_h,\,V W^V_h) \\[2pt] \mathrm{MultiHead}(Q,K,V) &= [\mathrm{head}_1;\dots;\mathrm{head}_H]W^O \end{aligned} headhMultiHead(Q,K,V)=Attention(QWhQ,KWhK,VWhV)=[head1;;headH]WO

步骤速览

  1. 分头映射dim=512, head_num=8 → 每头 d_k=64
    head_num, d_k = 8, dim // 8
    W_Q = nn.Linear(dim, dim)
    Q_h = W_Q(x).view(b, seq, head_num, d_k).transpose(1, 2)  # [b, 8, seq, 64]
    # K_h, V_h 同理
    
  2. 并行注意力(缩放点积 + softmax)
  3. 拼接 + 线性
    out = out.transpose(1, 2).contiguous().view(b, seq, dim)
    out = nn.Linear(dim, dim)(out)          # 最终输出 [b, seq, 512]
    
2 个头 = 2 组独立 (Q,K,V) 子空间

表达能力
不同头关注不同模式(句法、语义、指代…),组合后更灵活。

案例

下面给出可直接复用的“多头注意力”最小实现。

特点

  • nn.Module 封装,方便后续放进 nn.Sequential 或 Transformer;
  • 逐行中文注释,一眼看懂每一步在干什么;
  • 输出维度与原句向量一致 [batch, seq_len, dim],后续可继续堆叠。
import math
import torch
import torch.nn as nn
import torch.nn.functional as Fclass MultiHeadAttention(nn.Module):"""简化版多头注意力dim        : 模型总维度(你的例子里是 256)head_num   : 头数(你的例子里是 16)输出形状   : [batch, seq_len, dim]"""def __init__(self, dim: int = 256, head_num: int = 16):super().__init__()# assert :“必须保证 dim 能被 head_num 整除,否则就报错。”assert dim % head_num == 0 self.dim = dimself.head_num = head_numself.d_k = dim // head_num          # 每个头的维度# 3 个线性层一次性把 Q/K/V 投影出来(比 ModuleList 更简洁)self.W_q = nn.Linear(dim, dim)self.W_k = nn.Linear(dim, dim)self.W_v = nn.Linear(dim, dim)# 最后的输出线性变换self.W_o = nn.Linear(dim, dim)def forward(self, x):"""x : [batch, seq_len, dim]return : 与 x 形状相同"""batch, seq_len, _ = x.shape# 1) 线性投影 → [batch, seq_len, dim]Q = self.W_q(x)K = self.W_k(x)V = self.W_v(x)# 2) 拆成多头 → [batch, head_num, seq_len, d_k]def reshape(t):return t.view(batch, seq_len, self.head_num, self.d_k).transpose(1, 2)Q = reshape(Q)   # [batch, head_num, seq_len, d_k]K = reshape(K)V = reshape(V)# 3) 缩放点积注意力scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)  # [B,h,seq,seq]attn = F.softmax(scores, dim=-1)                                      # 归一化out = torch.matmul(attn, V)                                           # [B,h,seq,d_k]# 4) 合并多头 → [batch, seq_len, dim]out = out.transpose(1, 2).contiguous().view(batch, seq_len, self.dim)# 5) 最后线性变换return self.W_o(out)# ------------------ 测试 ------------------
if __name__ == "__main__":sentences = ["i am an NLPer"]vocab = sorted({w for sent in sentences for w in sent.split()})word2idx = {w: i for i, w in enumerate(vocab)}indices = torch.tensor([[word2idx[w] for w in sent.split()] for sent in sentences])dim = 256embedding = nn.Embedding(len(vocab), dim)x = embedding(indices)                  # [1, 4, 256]mha = MultiHeadAttention(dim=256, head_num=16)y = mha(x)                              # [1, 4, 256]print("输入:", x.shape)print("输出:", y.shape)

运行结果

输入: torch.Size([1, 4, 256])
输出: torch.Size([1, 4, 256])

二、位置编码(Positional Encoding)

自注意力本身“无顺序”,需显式注入位置信号。

2.1 固定正弦/余弦

给定位置 i 与维度 2j / 2j+1
PE(i,2j)=sin⁡⁣(i100002j/d)PE(i,2j+1)=cos⁡⁣(i100002j/d)\begin{aligned} PE(i,2j) &= \sin\!\left(\dfrac{i}{10000^{2j/d}}\right) \\[4pt] PE(i,2j+1) &= \cos\!\left(\dfrac{i}{10000^{2j/d}}\right) \end{aligned} PE(i,2j)PE(i,2j+1)=sin(100002j/di)=cos(100002j/di)

  • 例子:句子 <BOS> 我 喜欢 自然语言 处理(N=5, d=512)
    计算得到的 5×512 位置矩阵与词嵌入逐位相加即可。
    说明:

  • iii 是序列中位置的索引(从 000 开始)。

  • jjj 是词向量的维度索引(从 000d/2−1d/2 - 1d/21)。

  • 100001000010000 是一个超参数,用于控制频率的衰减。

句子长度为 555,编码向量维数 D=4D=4D=4
  • 外推特性
    已知 PE(pos)PE(pos)PE(pos) 可线性组合得到 PE(pos+k)PE(pos+k)PE(pos+k),模型可处理比训练集更长的句子。
505050 个词嵌入,维度 512512512 的位置编码热力图

2.2 可学习编码

直接把位置当 token 训,表现好但依赖最大长度超参。


小结
多头 = “多组独立子空间”并行注意力;
位置编码 = “给并行计算加上顺序感”。二者配合让 Transformer 既能并行又能保持序列有序。

http://www.xdnf.cn/news/1295371.html

相关文章:

  • vue更改style
  • 双椒派E2000D网络故障排查指南
  • 【Linux】库制作与原理
  • 2025年5月架构设计师综合知识真题回顾,附参考答案、解析及所涉知识点(三)
  • 苹果正计划大举进军人工智能硬件领域
  • 解决EKS中KEDA访问AWS SQS权限问题:完整的IRSA配置指南
  • 能源行业数字化转型:边缘计算网关在油田场景的深度应用
  • 支持pcm语音文件缓存顺序播放
  • 从感知到执行:人形机器人低延迟视频传输与多模态同步方案解析
  • Python 类元编程(导入时和运行时比较)
  • 【Linux学习|黑马笔记|Day3】root用户、查看权限控制信息、chmod、chown、快捷键、软件安装、systemctl、软连接、日期与时区
  • 17. 如何判断一个对象是不是数组
  • 技术速递|使用 AI Toolkit 构建基于 gpt-oss-20b 的应用程序
  • 工业元宇宙:迈向星辰大海的“玄奘之路”
  • 【Linux】常用命令(三)
  • Python 元类基础:从理解到应用的深度解析
  • PG靶机 - PayDay
  • 当img占不满div时,图片居中显示,两侧加当前图片模糊效果
  • 【Docker项目实战】使用Docker部署todo任务管理器
  • javaswing json格式化工具
  • 【2025】Datawhale AI夏令营-多模态RAG-Task3笔记-解决方案进阶
  • Redis7学习——Redis的十大类型String、List、Hash、Set、Zset
  • 模式设计:策略模式及其应用场景
  • Linux学习-UI技术
  • Python【算法中心 03】Docker部署Django搭建的Python应用流程实例(Docker离线安装配置+Django项目Docker部署)
  • Coze Studio 概览(十)--文档处理详细分析
  • 51单片机-51单片机最小系统
  • Java Stream API 中常用方法复习及项目实战示例
  • 普通电脑与云电脑的区别有哪些?全面科普
  • Apifox精准定义复杂API参数结构(oneOf/anyOf/allOf)