当前位置: 首页 > news >正文

大模型中的三角位置编码实现

Transformer中嵌入表示 + 位置编码的实现

import torch
import math
from torch import nn# 词嵌入位置编码实现
class EmbeddingWithPosition(nn.Module):"""vocab_size:词表大小emb_size: 词向量维度seq_max_len: 句子最大长度 (人为设定,例如GPT2的最大长度是1024) """def __init__(self, vocab_size, emb_size, dropout=0.1, seq_max_len=5000):self.seq_emb = nn.Embedding(vocab_size, emb_size) # 序列中每个token的embedding向量表示#  位置编码实现 (硬编码方式)position_idx = torch.arange(0, seq_max_len, dtype=torch.float).unsqueeze(-1)position_emb_fill = position_idx * torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000.0) / emb_size) # 三角位置编码实现position_emb = torch.zeros(seq_max_len, emb_size) # 位置编码 emb_size是嵌入维度大小position_emb[:, 0::2] = torch.sin(position_emb_fill)position_emb[:, 1::2] = torch.cos(position_emb_fill)self.register_buffer('pos_encoding', position_emb) # 固定参数,不需要trainself.dropout = nn.Dropout(dropout)def forward(self, x):x = self.seq_emb(x) # 嵌入层表示 (batch_size, seq_len, emb_size)# x = x + self.pos_encoding.unsqueeze(0)[:,:x.size()[1],:] # 添加位置编码x += self.pos_encoding.unsqueeze(0)return self.dropout(x)

自己动手实现易懂版本:

assert 10 % 2 == 0,  "wrong assert"
# 如果前面判断正确的话,则不会引发异常;否则,则会引发异常import torchimport torch
def creat_pe_absolute_sincos_embedding_gov(n_pos_vec, dim):assert dim % 2 == 0, "wrong dim"position_embedding = torch.zeros(n_pos_vec.numel(), dim, dtype=torch.float)omega = torch.arange(dim//2, dtype=torch.float)omega /= dim/2.omega = 1./(10000**omega)sita = n_pos_vec[:,None] @ omega[None,:]emb_sin = torch.sin(sita)emb_cos = torch.cos(sita)position_embedding[:,0::2] = emb_sinposition_embedding[:,1::2] = emb_cosreturn position_embeddingdef create_pe_absulute_sincos_embedding(n_pos_vec, dim):"""绝对位置编码:param n_pos_vec: 位置编码的长度向量:param dim: 词向量的维度:return: 位置编码"""assert dim % 2 == 0, "dim must be even"position_embedding = torch.zeros(n_pos_vec.numel(), dim, dtype=torch.float) # 三角函数位置编码omega = torch.arange(dim // 2, dtype=torch.float) # 0 ~ i, max_i: dim // 2omega *= 2omega /= dim omega = torch.pow(10000, omega) # 10000^(2i/dim)omega = 1 / omegaomega = omegaprint("n_pos_vec shape:",n_pos_vec.unsqueeze(1).shape)print("omega shape:", omega.shape).squeezeposition_embedding[:, 0::2] = torch.sin(n_pos_vec.unsqueeze(1) * omega) # 偶数位置position_embedding[:, 1::2] = torch.cos(n_pos_vec.unsqueeze(1) * omega) # 奇数位置return position_embeddingif __name__ == "__main__":n_pos = 4dim = 8n_pos_vec = torch.arange(n_pos, dtype=torch.float)position_embeddding = create_pe_absulute_sincos_embedding(n_pos_vec, dim)position_embeddding_1 = creat_pe_absolute_sincos_embedding_gov(n_pos_vec, dim)print(position_embeddding == position_embeddding_1)print("position embedding shape:", position_embeddding.shape)

参考版本

http://www.xdnf.cn/news/377281.html

相关文章:

  • PySide6 GUI 学习笔记——常用类及控件使用方法(常用类边距QMarginsF)
  • 【部署】win10的wsl环境下启动dify的web前端服务
  • 21.【.NET 8 实战--孢子记账--从单体到微服务--转向微服务】--单体转微服务--身份认证服务拆分规划
  • linux perf top分析系统性能
  • 光流 | 基于深度学习的光流估计算法汇总,原理,公式,流程图,代码
  • 人形机器人量产元年开启,AI与物理世界深度融合
  • CAS操作
  • Ceph集群故障处理 - PG不一致修复
  • [SV]等待32个instance的某一个信号的pulse,该怎么写?
  • Windows 系统 - Trae 内 终端 无法使用 node (重新配置 nodejs 路径)
  • 青藏高原东北部祁连山地区250m分辨率多年冻土空间分带指数图(2023)
  • AtCoder AT_abc405_d ABC405D - Escape Route
  • 智慧能源大数据平台建设方案(PPT)
  • 数字孪生实战笔记(1)数字孪生的含义、应用及技术体系
  • RPA 浏览器自动化:高效扩展与智能管理的未来
  • SpringBoot学习(上) , SpringBoot项目的创建(IDEA2024版本)
  • 基于阿伦尼斯模型的电池寿命预测:原理、应用与挑战
  • 数据结构:树(树的定义和基本术语)
  • JGL069垃圾填埋场模拟装置试验台
  • 力扣top100 矩阵置零
  • 近日部署跑通的若干多模态模型总结与论文概述
  • clangd与clang-tidy
  • Flutter PIP 插件 ---- 为iOS 重构PipController, Demo界面,更好的体验
  • 优选算法——前缀和
  • Java---StringJoiner 的使用
  • C++11新特性:深入解析decltype关键字及其与auto的区别
  • AI Agent(8):安全与伦理考量
  • [题解]2023CCPC黑龙江省赛 - Folder
  • 警惕C#版本差异多线程中的foreach陷阱
  • 每日c/c++题 备战蓝桥杯(P2241 统计方形(数据加强版))