当前位置: 首页 > ai >正文

自注意力(Self-Attention)和位置编码

自注意力

  • 给定序列 x 1 , … , x n \mathbf{x}_1, \ldots, \mathbf{x}_n x1,,xn, ∀ x i ∈ R d \forall \mathbf{x}_i \in \mathbb{R}^d xiRd

  • 自注意力池化层将 x i \mathbf{x}_i xi 当做key, value, query来对序列抽取特征得到 y 1 , … , y n \mathbf{y}_1, \ldots, \mathbf{y}_n y1,,yn, 这里

    y i = f ( x i , ( x 1 , x 1 ) , … , ( x n , x n ) ) ∈ R d \mathbf{y}_i = f(\mathbf{x}_i, (\mathbf{x}_1, \mathbf{x}_1), \ldots, (\mathbf{x}_n, \mathbf{x}_n)) \in \mathbb{R}^d yi=f(xi,(x1,x1),,(xn,xn))Rd
    在这里插入图片描述
    与 CNN、RNN 的比较
    在这里插入图片描述

CNNRNN自注意力
计算复杂度O( k n d 2 knd^2 knd2)O( n d 2 nd^2 nd2)O( n 2 d n^2d n2d)
并行度O( n n n)O( 1 1 1)O( n n n)
最长路径O( n / k n/k n/k)O( n n n)O( 1 1 1)

位置编码

  • 跟CNN/RNN不同,自注意力并没有记录位置信息
  • 位置编码将位置信息注入到输入里
    • 假设长度为 n n n 的序列是 X ∈ R n × d \mathbf{X} \in \mathbb{R}^{n \times d} XRn×d,那么使用位置编码矩阵 P ∈ R n × d \mathbf{P} \in \mathbb{R}^{n \times d} PRn×d 来输出 X + P \mathbf{X} + \mathbf{P} X+P 作为自编码输入
  • P \mathbf{P} P 的元素如下计算:
    p i , 2 j = sin ⁡ ( i 1000 0 2 j / d ) , p i , 2 j + 1 = cos ⁡ ( i 1000 0 2 j / d ) p_{i,2j} = \sin\left(\frac{i}{10000^{2j/d}}\right), \quad p_{i,2j+1} = \cos\left(\frac{i}{10000^{2j/d}}\right) pi,2j=sin(100002j/di),pi,2j+1=cos(100002j/di)

位置编码矩阵

  • P ∈ R n × d \mathbf{P} \in \mathbb{R}^{n \times d} PRn×d: p i , 2 j = sin ⁡ ( i 1000 0 2 j / d ) , p i , 2 j + 1 = cos ⁡ ( i 1000 0 2 j / d ) p_{i,2j} = \sin\left(\frac{i}{10000^{2j/d}}\right), \quad p_{i,2j+1} = \cos\left(\frac{i}{10000^{2j/d}}\right) pi,2j=sin(100002j/di),pi,2j+1=cos(100002j/di)

相对位置信息

  • 位于 i + δ i+\delta i+δ 处的位置编码可以线性投影位置 i i i 处的位置编码来表示

  • ω j = 1 / 1000 0 2 j / d \omega_j = 1/10000^{2j/d} ωj=1/100002j/d,那么在这里插入图片描述

总结

  • 自注意力池化层将 x i \mathbf{x}_i xi 当做key, value, query来对序列抽取特征
  • 完全并行、最长序列为1、但对长序列计算复杂度高
  • 位置编码在输入中加入位置信息,使得自注意力能够记忆位置信息

代码实现

首先导入必要的环境

import math
import torch
from torch import nn
from d2l import torch as d2l

自注意力

num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,num_hiddens, num_heads, 0.5)
attention.eval()

在这里插入图片描述
位置编码

#@save
class PositionalEncoding(nn.Module):"""位置编码"""def __init__(self, num_hiddens, dropout, max_len=1000):"""初始化位置编码类参数:num_hiddens: int编码的隐藏维度大小(即每个位置的编码维度)dropout: floatDropout的概率,用于防止过拟合max_len: int, 默认值为1000最大序列长度,用于生成足够长的位置编码矩阵"""super(PositionalEncoding, self).__init__()# 定义Dropout层,用于在前向传播中随机丢弃部分神经元self.dropout = nn.Dropout(dropout)# 创建一个形状为 (1, max_len, num_hiddens) 的位置编码矩阵 P# 其中 1 表示批量维度,max_len 表示序列长度,num_hiddens 表示编码维度self.P = torch.zeros((1, max_len, num_hiddens))# 生成位置索引的张量,形状为 (max_len, 1)# 每个位置索引除以 10000 的幂次,幂次由编码维度决定X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)# 对编码维度的偶数索引位置应用正弦函数self.P[:, :, 0::2] = torch.sin(X)# 对编码维度的奇数索引位置应用余弦函数self.P[:, :, 1::2] = torch.cos(X)def forward(self, X):"""前向传播函数,将位置编码添加到输入张量 X 上参数:X: torch.Tensor输入张量,形状为 (batch_size, seq_len, num_hiddens)返回:torch.Tensor添加了位置编码的张量,形状与输入张量相同"""# 将位置编码矩阵 P 的前 seq_len 个位置与输入张量 X 相加# 并将 P 移动到与 X 相同的设备(如 GPU 或 CPU)X = X + self.P[:, :X.shape[1], :].to(X.device)# 应用 Dropout 并返回结果return self.dropout(X)

行代表标记在序列中的位置,列代表位置编码的不同维度

encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])

在这里插入图片描述
在编码维度上降低频率
在这里插入图片描述

http://www.xdnf.cn/news/3889.html

相关文章:

  • Spring 中 @Value 注解实现原理
  • Vim 命令从头学习记录
  • 笔记本电脑升级计划(2017———2025)
  • JavaScript 笔记 --- part8 --- JS进阶 (part3)
  • 【NLP】32. Transformers (HuggingFace Pipelines 实战)
  • 全球化电商平台Azure云架构设计
  • 【计网】交换机和集线器对比
  • java学习之数据结构:四、树(代码补充)
  • 【Spring Boot】Spring Boot + Thymeleaf搭建mvc项目
  • flink rocksdb状态说明
  • 阿里云物联网平台--云产品流传
  • 7、Activiti-任务类型
  • 如何快速获取字符串的UTF-8或UTF-16编码二进制数据?数值转换成字符串itoa不是C标准?其它类型转换成字符串?其它类型转换成数值类型?
  • 虚幻引擎作者采访
  • 2.在Openharmony写hello world
  • 蓝桥杯 18. 积木
  • 记9(Torch
  • Leetcode刷题记录32——搜索二维矩阵 II
  • Dubbo(97)如何在物联网系统中应用Dubbo?
  • C语言 ——— 函数
  • Java设计模式: 工厂模式与策略模式
  • COlT_CMDB_linux_tomcat_20250505.sh
  • 【AI大模型】SpringBoot整合Spring AI 核心组件使用详解
  • 基于大模型的子宫腺肌病全流程预测与诊疗方案研究报告
  • 定位理论第一法则在医疗AI编程中的应用
  • Linux /dev/null文件用法介绍
  • 【KWDB 创作者计划】KWDB 2.2.0多模融合架构与分布式时序引擎
  • 如何选择合适的光源?
  • 【Linux网络#17】TCP全连接队列与tcpdump抓包
  • Linux55yum源配置、本机yum源备份,本机yum源配置,网络Yum源配置,自建yum源仓库