当前位置: 首页 > ops >正文

Transformer架构三大核心:位置编码(PE)、前馈网络(FFN)和多头注意力(MHA)。

本文将用“直觉理解” -> “一图看懂” -> “代码实现”三步法,帮你无痛掌握Transformer的三大核心:位置编码(PE)前馈网络(FFN)多头注意力(MHA)
在这里插入图片描述


1. 位置编码 (Positional Encoding, PE)

🧠 直觉理解:

Transformer像一位“眼观六路”的读者,能同时看光所有词,但这就失去了词的顺序信息。PE的作用就是给每个词加上一个“数字坐标”,告诉模型每个词的位置。

为什么用正弦函数? 因为它有一个神奇的特性:模型可以轻松学会“位置5的词”和“位置3的词”之间的相对距离关系,这让它能更好地处理训练时没见过的长句子。

📊 一图看懂:

输入词序列
(例如: 'I', 'love', 'you')
词嵌入
将每个词变为一个数字向量
生成位置编码
根据公式计算每个位置的固定向量
词向量矩阵
(Seq_Len, Dim)
对应位置相加
输出带位置的词向量
(模型真正的输入)

⌨️ 代码实现 (PyTorch):

import torch
import torch.nn as nn
import mathclass PositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000):super().__init__()pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)# 核心计算公式:分母部分div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term) # 偶数位置用sinpe[:, 1::2] = torch.cos(position * div_term) # 奇数位置用cosself.register_buffer('pe', pe.unsqueeze(0)) # 注册为不参与训练的缓冲区def forward(self, x):return x + self.pe[:, :x.size(1)] # 将PE切片后直接加到词嵌入上# 使用示例
d_model = 512
seq_len = 50
x = torch.randn(1, seq_len, d_model) # 模拟一个批次的词嵌入
pe = PositionalEncoding(d_model)
output = pe(x)
print(f"加上位置编码后的形状: {output.shape}")

2. 前馈神经网络 (Feed-Forward Network, FFN)

🧠 直觉理解:

FFN是每个位置上的“私人专家”。它对序列中每个词向量进行独立、相同的处理。其设计妙处在于“先扩维再缩回”(例如512维 -> 2048维 -> 512维),这相当于给了模型一个更宽敞的“工作台”,让它能在高维空间里更自由地组合和提炼特征,然后再将结果投影回原来的维度。

📊 一图看懂:

输入向量
线性层 (扩大)
e.g., Dim 512 -> 2048
GELU激活函数
引入非线性
线性层 (缩小)
e.g., 2048 -> Dim 512
Dropout
随机失活,防止过拟合
输出向量

⌨️ 代码实现 (PyTorch):

class FeedForward(nn.Module):def __init__(self, d_model, d_ff=2048, dropout=0.1):super().__init__()self.net = nn.Sequential(nn.Linear(d_model, d_ff),  # 扩维nn.GELU(),nn.Linear(d_ff, d_model),  # 缩回nn.Dropout(dropout),)def forward(self, x):return self.net(x) # 输入输出形状均为 (batch_size, seq_len, d_model)# 使用示例
ffn = FeedForward(d_model)
ffn_output = ffn(output) # 输入是PE的输出
print(f"FFN输入输出形状: {ffn_output.shape}")

3. 多头注意力 (Multi-Head Attention, MHA)

🧠 直觉理解:

核心:让模型同时从不同角度审视信息。
单头注意力好比你看一篇文章只关注一种关系(例如“谁做了什么”)。而多头注意力让你像有多双眼睛同时关注不同的方面(一双眼看“语法主谓”,一双眼看“情感色彩”,一双眼看“指代关系”),最后把所有观察结果综合起来,得到更全面的理解。

📊 一图看懂:

单个注意力头计算
• V
Q • K^T
Scale
/ sqrt(d_k)
Softmax
输入: Q, K, V
线性投影并分割成H个头
拼接H个头的输出
线性投影融合信息
最终输出

⌨️ 代码实现 (PyTorch):

class MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads, dropout=0.1):super().__init__()self.d_model = d_modelself.num_heads = num_headsself.d_k = d_model // num_heads# 定义投影矩阵self.w_q = nn.Linear(d_model, d_model)self.w_k = nn.Linear(d_model, d_model)self.w_v = nn.Linear(d_model, d_model)self.w_o = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(dropout)def forward(self, q, k, v, mask=None):batch_size = q.size(0)# 1. 线性投影并分头# 输入: (B, S, D) -> 投影: (B, S, D) -> 分头+转置: (B, H, S, d_k)q = self.w_q(q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)k = self.w_k(k).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)v = self.w_v(v).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)# 2. 计算缩放点积注意力# 公式: Attention(Q, K, V) = softmax(Q•K^T / sqrt(d_k)) • Vattn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)if mask is not None:attn_scores = attn_scores.masked_fill(mask == 0, -1e9)attn_weights = torch.softmax(attn_scores, dim=-1)attn_weights = self.dropout(attn_weights)attn_output = torch.matmul(attn_weights, v) # (B, H, S, d_k)# 3. 拼接所有头# 转置: (B, H, S, d_k) -> (B, S, H, d_k) -> 拼接: (B, S, D)attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)# 4. 最终线性投影return self.w_o(attn_output) # (B, S, D)# 使用示例 (自注意力)
num_heads = 8
mha = MultiHeadAttention(d_model, num_heads)
attn_output = mha(ffn_output, ffn_output, ffn_output) # Q, K, V 均来自FFN输出
print(f"MHA输出形状: {attn_output.shape}")

🧩 如何串联:一个迷你Transformer块

📊 一图看懂:

理解了组件,再看它们如何协作就非常容易了,流程图如下:

子层 2: 前馈网络
通过FFN: Linear → ReLU → Linear
复制 norm1_output 作为残差连接备用
Dropout
与子层输出相加: norm1_output + dropout(ffn_output)
层归一化
子层 1: 多头自注意力
计算自注意力: MHA(x, x, x)
复制 x 作为残差连接备用
Dropout
与原始输入相加: x + dropout(attn_output)
层归一化
输入: 张量 x
(batch_size, seq_len, d_model)
中间输出: norm1_output
输出: 变换后的张量
(batch_size, seq_len, d_model)

⌨️ 代码实现 (PyTorch):

class TransformerBlock(nn.Module):"""一个完整的Transformer编码层"""def __init__(self, d_model, num_heads, d_ff, dropout=0.1):super().__init__()self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)self.ffn = FeedForward(d_model, d_ff, dropout)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.dropout = nn.Dropout(dropout)def forward(self, x):# 子层1: 自注意力 + 残差连接 & 层归一化attn_output = self.self_attn(x, x, x) # 自注意力x = x + self.dropout(attn_output)     # 残差连接x = self.norm1(x)                     # 层归一化# 子层2: 前馈网络 + 残差连接 & 层归一化ffn_output = self.ffn(x)x = x + self.dropout(ffn_output)x = self.norm2(x)return x# 最终串联
mini_transformer = TransformerBlock(d_model, num_heads, d_ff=2048)
final_output = mini_transformer(pe(x)) # 从带位置编码的输入开始
print(f"迷你Transformer块输出形状: {final_output.shape}")

希望这个重新构思的回答更加清晰和深入!它现在包含了更多的设计洞见和更流畅的代码逻辑。

http://www.xdnf.cn/news/19269.html

相关文章:

  • 学习Python中Selenium模块的基本用法(12:操作Cookie)
  • TFS-2005《A Possibilistic Fuzzy c-Means Clustering Algorithm》
  • 使用 Python 自动化检查矢量面数据的拓扑错误(含导出/删除选项)
  • 算法题(196):最大异或对
  • 特殊符号在Html中的代码及常用标签格式的记录
  • Qt组件布局的经验
  • 线程池、锁策略
  • 机器视觉opencv教程(四):图像颜色识别与颜色替换
  • Linux中的ss命令
  • kotlin - 2个Activity实现平行视图,使用SplitPairFilter
  • 网络流量分析——使用Wireshark进行分析
  • Shell脚本编程——变量用法详解
  • Ruoyi-vue-plus-5.x第二篇MyBatis-Plus数据持久层技术:2.2 分页与性能优化
  • DAY17-新世纪DL(DeepLearning/深度学习)战士:Q(机器学习策略)2
  • AI 应用 图文 解说 (二) -- 百度智能云 ASR LIM TTS 语音AI助手源码
  • 自定义AXI_PWM_v1.0——ZYNQ学习笔记15
  • Spring Task快速上手
  • Maven学习笔记01
  • 【stm32】对射式红外传感器计次以及旋转编码器计次
  • SpringBoot 自研运行时 SQL 调用树,3 分钟定位慢 SQL!
  • 用产品经理的思维,重构AI时代的增长Playbook
  • 企业数据湖:从混沌到秩序的分层设计与治理策略
  • 11.1.5 实现文件删除,共享和共享下载排行榜
  • 分布式测试平台ITP:让自动化测试更高效、更稳定
  • SW - 用装配图的方式组合多个子零件然后转换成为零件,可维护性好
  • 组件通信终极指南:从Props Drilling到Context API
  • react-virtualized React 应用中高效渲染大型列表和表格数据的库
  • 扣子(coze)实践指南进阶篇——创建工作流,并将工作流接入智能体
  • 2025年8月个人工作生活总结
  • [Windows] 某音下载工具——自用