当前位置: 首页 > ds >正文

transformer 解码器和输出部分结构

解码器部分实现

目标

  • 了解解码器中各个组件的作用
  • 掌握解码器各个组成部分的实现过程

解码器拆解

  • 由N个解码器层堆叠而成
  • 每个解码器由三个子层连接结构组成
  • 第一个子层连接结构包括一个多头自注意力子层和规范化层以及一个残差连接
  • 第二个子层连接结构包括一个多头注意力子层和规范化层以及一个残差连接
  • 第三个子层连接结构包括一个潜亏全连接子层和规范化层以及一个残差连接

说明

  • 解码器层中的各个部分, 多头注意力, 规范化层, 前馈全连接网络, 子层连接结构和编码器中的实现相同, 可以直接拿来构建解码器

解码器部分

解码器层

目标

  • 了解解码器层的作用
  • 掌握解码器层的实现过程

作用
作为解码器的组成单元, 每个解码器层根据给定的输入向目标方向进行特征提取操作, 也称为解码过程

代码分析
import torch.nn as nnfrom transformer_test.attention import clones
from transformer_test.sub_layer_connection import SubLayerConnectionclass DecoderLayer(nn.Module):def __init__(self, size, self_attn, src_attn, feed_forward, dropout):""":param size: 词嵌入维度:param self_attn: 自注意力对象, Q=K=V:param src_attn: 多头注意力对象, Q!=K=V:param feed_forward: 前馈全连接对象:param dropout:"""super(DecoderLayer, self).__init__()self.self_attn = self_attnself.src_attn = src_attnself.feed_forward = feed_forwardself.sublayer = clones(SubLayerConnection(size, dropout), 3)self.size = sizedef forward(self, x, memory, src_mask, tgt_mask):""":param x: 上一层的输入x:param memory: 编码器层的语义存储变量:param src_mask: 源数据掩码张量:param tgt_mask: 目标数据掩码张量:return:"""m = memory# 将 x 传入第一个子层结构, 第一个子层结构的输入分别是 x, self-attn 函数, 因为是自注意力机制, 所以 Q=K=V# 最后一个参数(tgt_mask)是目标数据掩码张量, 这时要对目标数据进行遮掩, 因为此时模型可能还没有生成任何结果# 比如在解码器准备生成第一个字符或者词汇时, 我们其实已经传入了第一个字符以便计算损失# 但是我们不希望在生成第一个字符时候模型能利用这个信息, 因此我们会将其遮掩, 同样生成第二个字符, 模型只能使用第一个字符或者词汇信息, 第二个字符以及以后的信息都不允许被模型使用x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))# 第二个子层, 常规注意力机制, Q是输入的x, K, V是编码层输出的memory# 同样传入 src_mask, 对进行源数据遮掩的原因并非是抑制信息泄漏, 而是遮蔽掉对结果没有意义的字符而产生的注意力值# 以此提升模型效果和训练速度, 这样就完成了第二个子层的处理x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))# 最后一个子层就是前馈全连接子层, 经过它的处理就可以返回结果, 这就是解码器结构return self.sublayer[2](x, self.feed_forward)

解码器

目标

  • 了解解码器的作用
  • 掌握解码器的实现过程

作用
根据解码器的结果以及上一次预测的结果, 对下一次可能出现的’值’进行特征表示

解码器部分

代码分析
import torch.nn as nnfrom transformer_test.attention import clones
from transformer_test.layer_norm import LayerNormclass Decoder(nn.Module):def __init__(self, layer, N):super(Decoder, self).__init__()self.layers = clones(layer, N)self.norm = LayerNorm(layer.size)def forward(self, x, memory, source_mark, target_mark):""":param x: 目标数据嵌入表示:param memory: 编码器层输出:param source_mark: 源数据掩码张量:param target_mark: 目标数据掩码张量:return:"""# 对每个层进行循环, 让变量 x 通过每一层处理for layer in self.layers:x = layer(x, memory, source_mark, target_mark)# 进行规范化返回return self.norm(x)

输出部分实现

目标

  • 了解线性层和softmax作用
  • 掌握线性层和softmax的实现过程

作用
通过对上一步的线性变化得到指定维度的输出, 也就是转换维度的作用

输出部分

代码分析
import torch.nn as nn
import torch.nn.functional as F# 将线性层和softmax计算层一起实现, 因为二者的共同目标是生成最后的结构
# 因此吧类的名字叫做Generator, 生成器类
class Generator(nn.Module):def __init__(self, d_model, vocab):"""初始化函数的输入参数有两个, d_model代表词嵌入维度, vocab_size代表词表大小"""super(Generator, self).__init__()# 首先使用线性层示例话, 得到project对象等待使用# 这个线性层的参数有两个, 就是初始化函数传进来的两个参数 d_model, vocab_sizeself.project = nn.Linear(d_model, vocab)def forward(self, x):"""向前逻辑函数中输入是上一层的输出张量x:param x::return:"""# 在函数中, 首先使用 project 对 x 进行线性变化, 然后使用 log_softmax 进行 softmax 处理, log_softmax 就是对 softmax 取对数, 因为对数函数是单调递增的, 对我们去最大值概率没有影响return F.log_softmax(self.project(x), dim=-1)
总结

线性层作用
对上一步的线性变化得到指定维度的输出, 也就是转换维度的作用

softmax层作用
使以后一维的向量中的数字缩放到0~1, 的概率值域内, 并满足他们的和为1

学习并实现了softmax层和Generator类

  • 初始化函数的参数有两个, d_model 代表词嵌入维度, vocab_size 代表词表大小
  • forward函数接受上一层的输出
  • 最终获得经过线性层和softmax层处理的结果
http://www.xdnf.cn/news/1330.html

相关文章:

  • gradle可用的下载地址(免费)
  • Linux 内核中 cgroup 子系统 cpuset 是什么?
  • nodejs模块暴露数据的方式,和引入(导入方式)方式
  • 高级java每日一道面试题-2025年4月21日-基础篇[反射篇]-如何使用反射获取一个类的所有方法?
  • 移动通信运营商对MTU的大小设置需求
  • 【codeforces思维题】前缀和的巧妙应用(2053B)
  • 【AI News | 20250422】每日AI进展
  • 计算机组成原理---总线系统的详细概述
  • HCIP-H12-821 核心知识梳理 (5)
  • 如何修改文件termsrv.dll实现多用户同时远程
  • 一个关于相对速度的假想的故事-4
  • AGI大模型(12):向量检索之关键字搜索
  • 企业战略到数字化落地 —— 第四章 SOP 的概念
  • 几种电气绝缘类型
  • Mininet--node.py源码解析
  • 学习笔记——《Java面向对象程序设计》-抽象和接口
  • 实验1python基本网络应用
  • 为TA开发人员介绍具有最新改进的Kinibi-610a
  • 【Vue3 / TypeScript】 项目兼容低版本浏览器的全面指南
  • 【MySQL】数据库基础
  • 从马拉松到格斗大赛:人形机器人撕开的万亿市场,正在改写AI规则
  • STM32单片机入门学习——第45节: [13-2] 修改频主睡眠模式停止模式待机模式
  • G1 人形机器人硬件构成与接口
  • 图像挖掘课程笔记-第一章:了解机器视觉
  • 【TeamFlow】4.3.2 细化时间单位
  • 设备预测性维护系统部署成本:技术架构与成本优化策略解析
  • Linux——基于socket编程实现简单的Tcp通信
  • Size of map written was 1, but number of entries written was 0. 异常分析
  • 进阶篇 第 7 篇 (终章):融会贯通 - 多变量、模型选择与未来之路
  • 数据可视化--数据探索性分析