当前位置: 首页 > ai >正文

大模型常用位置编码方式

深度学习中常见的位置编码方式及其Python实现:


一、固定位置编码(Sinusoidal Positional Encoding)
原理
通过不同频率的正弦和余弦函数生成位置编码,使模型能够捕捉绝对位置和相对位置信息。公式为:

公式标准数学表达
P E ( p o s , 2 i ) = sin ⁡ ( p o s 1000 0 2 i / d model ) P E ( p o s , 2 i + 1 ) = cos ⁡ ( p o s 1000 0 2 i / d model ) \begin{aligned} PE_{(pos,2i)} &= \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) \\ PE_{(pos,2i+1)} &= \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) \end{aligned} PE(pos,2i)PE(pos,2i+1)=sin(100002i/dmodelpos)=cos(100002i/dmodelpos)

公式解析

  1. 变量定义
    pos:token在序列中的绝对位置(从0开始计数)

    i:位置编码向量的维度索引(范围:0 ≤ i < d_model/2)

    d_model:模型嵌入维度(如Transformer默认的512)

  2. 核心设计
    • 交替使用正弦/余弦:偶数维度用正弦函数,奇数维度用余弦函数,形成周期性编码。

    • 频率衰减特性:维度越高(i增大),分母指数项 2 i d model \frac{2i}{d_{\text{model}}} dmodel2i越大,导致频率 1000 0 − 2 i / d model 10000^{-2i/d_{\text{model}}} 100002i/dmodel越小,编码的周期性波长越长。

    • 位置唯一性:每个位置pos的编码向量唯一,且相邻位置的编码差异与相对距离成比例。

  3. 数学特性
    • 相对位置捕捉:通过三角恒等式,任意两个位置的编码内积仅与相对距离pos_i - pos_j相关,隐含相对位置信息。

    • 外推能力:周期性设计使模型能处理超过训练时最大长度的序列。

关键参数作用

参数作用示例值(以BERT-base为例)
d_model定义编码维度,影响模型容量768
10000控制频率衰减速度,值越大高频分量衰减越快固定超参数
pos序列位置索引输入序列的第0/1/2…位

  • Python实现(方式一)
import numpy as np
import matplotlib.pyplot as pltdef sinusoidal_position_encoding(max_len, d_model):pe = np.zeros((max_len, d_model))position = np.arange(max_len)[:, np.newaxis]div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))pe[:, 0::2] = np.sin(position * div_term)pe[:, 1::2] = np.cos(position * div_term)return pe# 示例
max_len, d_model = 50, 64
pe = sinusoidal_position_encoding(max_len, d_model)# 可视化
plt.imshow(pe, cmap='viridis', aspect='auto')
plt.title("Sinusoidal Position Encoding")
plt.colorbar()
plt.show()

输出示例:生成一个形状为 (50, 64) 的编码矩阵,低频维度变化平缓,高频维度变化剧烈。

  • Pytorch实现
import torch
import torch.nn as nnclass LearnablePositionalEncoding(nn.Module):def __init__(self, d_model: int, max_len: int = 512):super().__init__()# 初始化位置编码矩阵为可训练参数 (1, max_len, d_model)self.pe = nn.Parameter(torch.empty(1, max_len, d_model))# 正态分布初始化(标准差0.02,与Transformer常规初始化一致)nn.init.normal_(self.pe, mean=0.0, std=0.02)def forward(self, x: torch.Tensor) -> torch.Tensor:"""输入x形状: [batch_size, seq_len, d_model]输出形状: [batch_size, seq_len, d_model]"""seq_len = x.size(1)# 取前seq_len个位置编码(避免越界)position_emb = self.pe[:, :seq_len, :]# 将位置编码与输入相加return x + position_emb

二、可学习位置编码(Learnable Positional Encoding)
原理
将位置编码作为可训练参数,通过嵌入层动态学习每个位置的表示。

Python实现(PyTorch)

import torch
import torch.nn as nnclass LearnablePositionalEncoding(nn.Module):def __init__(self, max_len, d_model):super().__init__()self.pe = nn.Embedding(max_len, d_model)def forward(self, x):batch_size, seq_len = x.size(0), x.size(1)positions = torch.arange(seq_len, device=x.device).expand(batch_size, seq_len)return x + self.pe(positions)# 示例
d_model, max_len = 64, 50
inputs = torch.randn(32, max_len, d_model)  # 模拟输入 (batch_size=32, seq_len=50)
pe_layer = LearnablePositionalEncoding(max_len, d_model)
output = pe_layer(inputs)
print("Encoded shape:", output.shape)  # 输出:torch.Size([32, 50, 64])

优势:灵活性高,适合特定任务;缺点:依赖预定义的最大序列长度。


三、相对位置编码(Relative Positional Encoding)
原理
关注序列元素之间的相对位置差异,常用于长序列建模。

Python实现

class RelativePositionalEncoding(nn.Module):def __init__(self, max_rel_pos, d_model):super().__init__()self.emb = nn.Embedding(2 * max_rel_pos + 1, d_model)def forward(self, seq_len):# 生成相对位置索引矩阵(对称)rel_pos = torch.arange(seq_len).unsqueeze(0) - torch.arange(seq_len).unsqueeze(1)rel_pos = torch.clamp(rel_pos + seq_len - 1, 0, 2 * seq_len - 2)return self.emb(rel_pos)# 示例
d_model, max_rel_pos = 64, 10
rel_pe = RelativePositionalEncoding(max_rel_pos, d_model)
rel_enc = rel_pe(seq_len=5)
print("Relative encoding shape:", rel_enc.shape)  # 输出:torch.Size([5, 5, 64])

应用场景:Transformer-XL、音乐生成等长序列任务。


四、旋转位置编码(Rotary Positional Encoding, RoPE)
原理
通过旋转矩阵将绝对位置信息融入注意力计算,保持相对位置的线性性质。
旋转矩阵公式的标准数学表达式及解析:

标准数学公式
R θ , m = [ cos ⁡ ( m θ ) − sin ⁡ ( m θ ) sin ⁡ ( m θ ) cos ⁡ ( m θ ) ] , q ′ = R θ , m q , k ′ = R θ , n k R_{\theta,m} = \begin{bmatrix} \cos(m\theta) & -\sin(m\theta) \\ \sin(m\theta) & \cos(m\theta) \end{bmatrix}, \quad q' = R_{\theta,m}q, \quad k' = R_{\theta,n}k Rθ,m=[cos(mθ)sin(mθ)sin(mθ)cos(mθ)],q=Rθ,mq,k=Rθ,nk


Python实现

def rotate_half(x):x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]return torch.cat((-x2, x1), dim=-1)def apply_rotary_emb(q, k, freq):cos, sin = freq.cos(), freq.sin()q_rot = q * cos + rotate_half(q) * sink_rot = k * cos + rotate_half(k) * sinreturn q_rot, k_rot# 示例
d_model, seq_len = 64, 50
q = torch.randn(1, seq_len, d_model)
k = torch.randn(1, seq_len, d_model)
freq = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000) / d_model))
q_rot, k_rot = apply_rotary_emb(q, k, freq)
print("Rotated shapes:", q_rot.shape, k_rot.shape)  # 输出:torch.Size([1, 50, 64])

优势:支持任意长度外推,广泛用于LLaMA、ChatGLM等大模型。


总结与选择建议

方法适用场景优点缺点
固定位置编码通用NLP任务确定性,无需训练无法自适应长序列
可学习位置编码短序列任务灵活性高依赖预定义长度,泛化性差
相对位置编码长文本生成、音乐建模捕捉相对位置关系计算复杂度较高
旋转位置编码大语言模型(LLaMA等)支持外推,数学性质优雅实现较复杂
http://www.xdnf.cn/news/5788.html

相关文章:

  • 信息论14:从互信息到信息瓶颈——解锁数据压缩与特征提取的秘密
  • 分析Docker容器Jvm 堆栈GC信息
  • 【简单易懂】SSE 和 WebSocket(Java版)
  • 删除购物车中一个商品
  • Unity
  • KMDA-6920成功助力印度智慧钢厂SCADA系统,打造高效可靠的生产监控平台
  • 菜狗的脚步学习
  • 【android bluetooth 框架分析 02】【Module详解 7】【VendorSpecificEventManager 模块介绍】
  • 前端开发避坑指南:React 代理配置常见问题与解决方案
  • BFS算法篇——打开智慧之门,BFS算法在拓扑排序中的诗意探索(上)
  • 机器学习——聚类算法练习题
  • [Java实战]Spring Boot 3构建 RESTful 风格服务(二十)
  • java使用 FreeMarker 模板生成包含图片的 `.doc` 文件
  • RustDesk:开源电脑远程控制软件
  • 端侧智能重构智能监控新路径 | 2025 高通边缘智能创新应用大赛第三场公开课来袭!
  • 霍夫圆变换全面解析(OpenCV)
  • 6. 多列布局/用户界面 - 杂志风格文章布局
  • 手机换IP真的有用吗?可以干什么?
  • spark-local模式
  • WM_TIMER定时器消息优先级低,可能会被系统丢弃,导致定时任务无法正常执行
  • T-BOX硬件方案深度解析:STM32与SD NAND Flash存储的完美搭配
  • Linux中find命令用法核心要点提炼
  • spark-standalone
  • http断点续传
  • Games101作业四
  • 在Ubuntu服务器上部署Label Studio
  • 从SAM看交互式分割与可提示分割的区别与联系:Interactive Segmentation Promptable Segmentation
  • Java基础(IO)
  • Android Native 之 自定义进程
  • 【氮化镓】电子辐照下温度对GaN位移阈能的影响