正余弦位置编码和RoPE位置编码
绝对位置编码
-
绝对位置编码:为每个位置赋予独一无二的编码向量,来明确表示输入序列中每个元素的绝对位置。像 BERT 中采用的正弦余弦位置编码,会根据位置索引,利用特定的正弦和余弦函数计算出对应位置向量的每个维度值。位置pos在维度 2 i 2i 2i的值是 s i n ( p o s / 1000 0 2 i / d m o d e l ) sin(pos/10000^{2i/d_{model}}) sin(pos/100002i/dmodel),在维度 2 i + 1 2i + 1 2i+1的值是 c o s ( p o s / 1000 0 2 i / d m o d e l ) cos(pos/10000^{2i/d_{model}}) cos(pos/100002i/dmodel) ,其中 d m o d e l d_{model} dmodel是模型维度。这种编码方式让模型可以直接学习到每个位置对应的特征表示。
-
优势:绝对位置编码为每个位置赋予唯一编码向量,能清晰表达元素的绝对位置。在文本分类、情感分析等任务中,模型可依据绝对位置编码快速定位关键信息,做出更准确的决策。而且绝对位置编码计算简单,易于实现,在 Transformer 架构早期被广泛应用,稳定性和可解释性好,有助于理解模型如何利用位置信息进行学习和预测。
-
劣势:绝对位置编码对序列长度敏感,训练和推理时序列长度需一致。若推理时序列长度与训练时不同,需重新计算编码,甚至重新训练模型,这限制了模型的应用灵活性。在实际应用中,文本长度多变,绝对位置编码难以适应这种变化。同时,绝对位置编码在处理长距离依赖时表现欠佳,模型难以捕捉相距较远元素间的关系,影响对复杂语义结构的理解,在长文本阅读理解任务中可能无法充分挖掘文本含义。
"""
@author : Hyunwoong
@when : 2019-10-22
@homepage : https://github.com/gusdnd852
"""
import torch
from torch import nnclass PositionalEncoding(nn.Module):"""compute sinusoid encoding."""def __init__(self, d_model, max_len, device):"""constructor of sinusoid encoding class:param d_model: dimension of model:param max_len: max sequence length:param device: hardware device setting"""super(PositionalEncoding, self).__init__()# same size with input matrix (for adding with input matrix)self.encoding = torch.zeros(max_len, d_model, device=device)self.encoding.requires_grad = False # we don't need to compute gradientpos = torch.arange(0, max_len, device=device)pos = pos.float().unsqueeze(dim=1)# 1D => 2D unsqueeze to represent word's position_2i = torch.arange(0, d_model, step=2, device=device).float()# 'i' means index of d_model (e.g. embedding size = 50, 'i' = [0,50])# "step=2" means 'i' multiplied with two (same with 2 * i)self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / d_model)))self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))# compute positional encoding to consider positional information of wordsdef forward(self, x):# self.encoding# [max_len = 512, d_model = 512]batch_size, seq_len = x.size()# [batch_size = 128, seq_len = 30]return self.encoding[:seq_len, :]# [seq_len = 30, d_model = 512]# it will add with tok_emb : [128, 30, 512]
相对位置编码
- 相对位置编码:重点关注序列中元素之间的相对距离或顺序关系,而非绝对位置。例如 RoPE(Rotary Position Embedding),它基于旋转矩阵的概念,通过对位置向量进行旋转操作,使得模型在计算注意力时,能够自动捕捉到元素间的相对位置信息。在计算注意力分数时,会考虑查询向量和键向量的相对位置,让模型能更有效地利用相对位置关系。
- 优势:相对位置编码聚焦于元素间的相对距离和顺序关系,能有效捕捉上下文信息。在处理长文本时,模型可利用相对位置编码把握长距离依赖,理解词汇间的语义关联。在机器翻译任务中,模型能依据相对位置编码,更准确地处理源语言和目标语言间的词序差异,提高翻译质量。同时,相对位置编码对序列长度变化适应性强,在不同长度输入下表现稳定,可直接应用于训练时未见过的序列长度,无需调整,提升了模型的泛化能力。
- 劣势:相对位置编码缺乏对元素绝对位置的直接表达,对于某些依赖绝对位置信息的任务,可能无法提供足够信息。在文本分类任务中,若关键信息的位置固定,模型难以利用相对位置编码准确捕捉该信息。此外,相对位置编码的计算和实现更为复杂,像 RoPE 需引入旋转矩阵操作,增加了计算复杂度和模型训练难度,也可能导致训练时间延长和资源消耗增加。
旋转编码通过在二维平面上旋转来变换特征对。也就是说,它将d个特征组织为d/2对。每个对可以被认为是二维平面中的一个坐标,编码将根据标记的位置将其旋转一个角度。
class RotaryPositionalEmbeddings(nn.Module):def__init__(self, d: int, base: int = 10_000):super().__init__()self.base = baseself.d = dself.cos_cached = Noneself.sin_cached = Nonedef _build_cache(self, x: torch.Tensor):if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:returnseq_len = x.shape[0]theta = 1. / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device) # THETA = 10,000^(-2*i/d) or 1/10,000^(2i/d)seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device) #Position Index -> [0,1,2...seq-1]idx_theta = torch.einsum('n,d->nd', seq_idx, theta) #Calculates m*(THETA) = [ [0, 0...], [THETA_1, THETA_2...THETA_d/2], ... [seq-1*(THETA_1), seq-1*(THETA_2)...] ]idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1) # [THETA_1, THETA_2...THETA_d/2] -> [THETA_1, THETA_2...THETA_d]self.cos_cached = idx_theta2.cos()[:, None, None, :] #Cache [cosTHETA_1, cosTHETA_2...cosTHETA_d]self.sin_cached = idx_theta2.sin()[:, None, None, :] #cache [sinTHETA_1, sinTHETA_2...sinTHETA_d]def _neg_half(self, x: torch.Tensor):d_2 = self.d // 2 #return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1) # [x_1, x_2,...x_d] -> [-x_d/2, ... -x_d, x_1, ... x_d/2]def forward(self, x: torch.Tensor):self._build_cache(x)neg_half_x = self._neg_half(x)x_rope = (x * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]]) # [x_1*cosTHETA_1 - x_d/2*sinTHETA_d/2, ....]return x_rope