当前位置: 首页 > ai >正文

手搓transformer

transformer_embedding过程

用到的一些工具:

import torch
from torch import nn
import torch.nn.functional as f
import math
from torch import Tensorrandom_torch = torch.rand(3,4)#生成一个3行4列的随机矩阵(测试torch功能)
print(random_torch)

token_embedding的实现(用于将每个字转换为向量)

#nn.Eembedding是pytorch中的一个类,用于将输入的索引转换为对应的词向量

class TokenEmbedding(nn.Embedding):def __init__(self,vocab_size,d_model):  #vocab_size是词汇表的大小,d_model是词向量的维度"""super()函数用于调用父类的构造函数,这里调用的是nn.Eembedding的构造函数,初始化嵌入层"""super().__init__(vocab_size,d_model,padding_idx=1)   #padding_idx=1表示索引为1的元素是填充元素

position_embedding的实现

class PositionalEmbedding(nn.Module):def __init__(self,d_model,max_len,device=None):  #max_len是最大序列长度,device是设备可以是cpu或者gpusuper().__init__()self.encoding = torch.zeros(max_len,d_model,device=device)  #创建一个max_len行d_model列的全0矩阵,用于存储位置编码self.encoding.requires_grad = False  #将矩阵的requires_grad属性设置为False,表示不需要计算梯度"""位置编码的作用是给输入序列中的每个位置添加一个位置信息,这样模型才能知道每个位置的相对位置。这些位置编码是预先固定好的,训练过程中不需要更新,所以不需要计算梯度。"""pos = torch.arange(0,max_len,device=device)  #创建一个从0到max_len-1的一维张量,表示序列中每个位置的索引pos = pos.float().unsqueeze(dim=1)  #在维度 1 上增加一个维度,将pos张量的维度从1扩展到2,这样就可以和后面的张量进行广播操作了_2i = torch.arange(0,d_model,step=2,device=device).float()  #创建一个从0到d_model-1的一维张量,步长为2,表示词向量的偶数纬度#以下是计算位置编码的公式self.encoding[:,0::2] = torch.sin(pos / (10000 ** (_2i / d_model)))  #计算偶数位置的位置编码self.encoding[:,1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))  #计算奇数位置的位置编码#以下是前向传播函数def forward(self,x):batch_size,seq_len = x.size()   #获取输入张量的形状return self.encoding[:seq_len,:]    #返回位置编码张量的前seq_len行,即输入序列的位置编码

transformer的实现

class Transformer(nn.Module):def __init__(self,vocab_size,d_model,max_len,drop_prob,device=None):super().__init__()self.tok_emb = TokenEmbedding(vocab_size,d_model)  #创建一个TokenEmbedding对象,用于将输入的索引转换为对应的词向量self.pos_emb = PositionalEmbedding(d_model,max_len,device)  #创建一个PositionalEmbedding对象,用于给输入序列中的每个位置添加一个位置信息self.dropout = nn.Dropout(drop_prob)  #创建一个Dropout层,用于随机将输入张量中的一些元素置为0,防止过拟合def forward(self,x):tok_emb = self.tok_emb(x)  #将输入的索引转换为对应的词向量pos_emb = self.pos_emb(x)  #给输入序列中的每个位置添加一个位置信息return self.dropout(tok_emb + pos_emb)  #将词向量和位置编码相加,并通过Dropout层进行随机置0,返回最终的输入张量

举例理解

假设词汇表:猫=0, 吃=1, 鱼=2
输入序列:x = [[0, 1, 2]](batch_size=1, seq_len=3)
词嵌入 (tok_emb)
假设学习到的嵌入向量:
猫 → [0.1, 0.2, 0.3, 0.4]
吃 → [0.5, 0.6, 0.7, 0.8]
鱼 → [0.9, 1.0, 1.1, 1.2]
→ 输出 tok_emb(x):
位置编码 (pos_emb)
用正弦公式计算位置向量(简化值):
位置0(猫):[0.0, 1.0, 0.0, 1.0]
位置1(吃):[0.8, 0.5, 0.01, 1.0]
位置2(鱼):[0.9, -0.4, 0.02, 1.0]
→ 输出 pos_emb(x):
output = dropout(tok_emb + pos_emb)
[[[0.1+0.0, 0.2+1.0, 0.3+0.0, 0.4+1.0],   # 猫 → [0.1, 1.2, 0.3, 1.4][0.5+0.8, 0.6+0.5, 0.7+0.01, 0.8+1.0],  # 吃 → [1.3, 1.1, 0.71, 1.8][0.9+0.9, 1.0-0.4, 1.1+0.02, 1.2+1.0]]] # 鱼 → [1.8, 0.6, 1.12, 2.2]Dropout:以10%概率随机置零部分值(例如可能变成):
[[[0.1, 1.2, 0.0, 1.4],   # 第2维被置零[1.3, 1.1, 0.71, 1.8],[1.8, 0.6, 1.12, 2.2]]]

最终输出形状:(batch_size, seq_len, d_model) = (1, 3, 4)

http://www.xdnf.cn/news/12842.html

相关文章:

  • 【数据结构与算法】从广度优先搜索到Dijkstra算法解决单源最短路问题
  • springboot3.5整合Spring Security6.5默认密码没有打印输出控制台排查过程
  • DeepSeek 终章:破局之路,未来已来
  • 图像超分辨率
  • 爱抚宠物小程序源代码+lw+ppt
  • 数据库学习(三)——MySQL锁
  • for循环应用
  • 【西门子杯工业嵌入式-6-ADC采样基础】
  • 详细叙述一下Spring如何创建bean
  • Python训练营打卡DAY48
  • 华为IP(8)(OSPF开放最短路径优先)
  • 树状数组学习笔记
  • 振动力学:无阻尼多自由度系统(受迫振动)
  • SQL进阶之旅 Day 21:临时表与内存表应用
  • Spring MVC请求处理流程和DispatcherServlet机制解析
  • 【Go语言基础【18】】Map基础
  • 2025-04-28-堆、栈及其应用分析
  • 算法专题七:分治
  • 【CATIA的二次开发23】抽象对象Document涉及文档激活控制的方法
  • serv00 ssh登录保活脚本-邮件通知版
  • 【构建】CMake 常用函数和命令清单
  • leetcode189-轮转数组
  • Prefix Caching 详解:实现 KV Cache 的跨请求高效复用
  • c++对halcon的动态链接库dll封装及调用(细细讲)
  • 【CSS-8】深入理解CSS选择器权重:掌握样式优先级的关键
  • 【拆机系列】暴力拆解AOC E2270SWN6液晶显示屏
  • Python训练营打卡Day48(2025.6.8)
  • 【LangChain4J】LangChain4J 第三弹:多模态与文生图的实现
  • leetcode_56 合并区间
  • el-table的select回显问题