深入理解Transformer:编码器与解码器的核心原理与实现
1. 引言
Transformer 是 NLP 和 CV 领域的革命性架构,由 Google 在 2017 年提出(论文《Attention Is All You Need》)。它抛弃了传统的 RNN 和 CNN,完全依赖 自注意力(Self-Attention) 机制,大幅提升了模型并行化能力和长序列建模效果。
本文将深入解析 Transformer 的两大核心组件——编码器(Encoder) 和 解码器(Decoder),剖析它们的作用、内部结构及协作方式,并附上关键代码实现(基于 PyTorch)。
2. 编码器(Encoder)详解
2.1 编码器的核心任务
编码器的目标是将输入序列(如句子、图像块)转换为 富含上下文信息的向量表示,供解码器使用。例如,在机器翻译中,编码器将源语言句子编码为高维语义向量。
2.2 编码器的核心组件
(1) 输入嵌入层(Input Embedding)
将离散的输入符号(如单词)映射为连续的向量表示:
embedding = nn.Embedding(vocab_size, d_model) # d_model 是向量维度(如512)
x = embedding(input_tokens) # 输入形状: [batch_size, seq_len] → 输出形状: [batch_size, seq_len, d_model]
(2) 位置编码(Positional Encoding)
由于 Transformer 没有时序信息,需显式注入位置编码(PE):
class PositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000):super().__init__()pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term) # 偶数维正弦pe[:, 1::2] = torch.cos(position * div_term) # 奇数维余弦self.register_buffer('pe', pe) # 不参与训练def forward(self, x):return x + self.pe[:x.size(1)] # 叠加位置信息
(3) 多头自注意力(Multi-Head Self-Attention)
计算序列内元素的依赖关系,公式:
Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dkQKT)V
多头机制将注意力拆分为 hhh 个并行头(如 h=8h=8h=8),增强模型表达能力:
# PyTorch 实现(简化版)
multihead_attn = nn.MultiheadAttention(d_model, num_heads, dropout=0.1)
attn_output, _ = multihead_attn(query, key, value, attn_mask=None)
(4) 前馈神经网络(FFN)
对每个位置的向量进行非线性变换:
ffn = nn.Sequential(nn.Linear(d_model, d_ff), # d_ff 通常为 2048nn.ReLU(),nn.Linear(d_ff, d_model)
(5) 残差连接与层归一化
每个子层后应用:
x = x + dropout(sublayer(x)) # 残差连接
x = LayerNorm(x) # 层归一化
(6) 堆叠 N 层
编码器通常由多个相同层堆叠而成(如 N=6):
class EncoderLayer(nn.Module):def __init__(self, d_model, num_heads, d_ff, dropout):super().__init__()self.self_attn = MultiHeadAttention(d_model, num_heads)self.ffn = PositionwiseFFN(d_model, d_ff)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.dropout = nn.Dropout(dropout)def forward(self, x, mask):# 自注意力子层attn_output = self.self_attn(x, x, x, mask)x = self.norm1(x + self.dropout(attn_output))# FFN 子层ffn_output = self.ffn(x)x = self.norm2(x + self.dropout(ffn_output))return x
3. 解码器(Decoder)详解
3.1 解码器的核心任务
解码器基于编码器的输出,自回归生成目标序列(如翻译结果、生成文本)。其关键特点是:
- 掩码自注意力:防止解码时“偷看”未来信息。
- 编码器-解码器注意力:建立源序列与目标序列的关联。
3.2 解码器的核心组件
(1) 掩码自注意力(Masked Self-Attention)
通过掩码(上三角矩阵)遮挡未来位置:
# 生成掩码(seq_len=5)
mask = torch.triu(torch.ones(5, 5), diagonal=1).bool()
# 输出:
# [[False, True, True, True, True],
# [False, False, True, True, True],
# ...]
(2) 编码器-解码器注意力(Cross-Attention)
- Query 来自解码器,Key/Value 来自编码器:
cross_attn = nn.MultiheadAttention(d_model, num_heads)
attn_output, _ = cross_attn(query=decoder_output, # 解码器当前输出key=encoder_output, # 编码器的输出value=encoder_output)
(3) 输出层(Linear + Softmax)
将解码结果映射到词汇表空间:
linear = nn.Linear(d_model, tgt_vocab_size)
logits = linear(decoder_output) # 形状: [batch_size, seq_len, tgt_vocab_size]
probs = F.softmax(logits, dim=-1)
4. 编码器 vs 解码器:核心区别对比
以机器翻译为例(英文→中文):
- 编码阶段:编码器处理
"I love cats"
,输出上下文向量。 - 解码阶段:
- 第1步:输入
<START>
,预测"我"
; - 第2步:输入
<START> 我
,预测"爱"
; - 第3步:输入
<START> 我爱
,预测"猫"
; - 直到生成
<END>
。
- 第1步:输入
为了更清晰地理解Transformer中编码器和解码器的差异,我们从注意力机制、输入依赖、掩码需求和典型应用四个维度进行对比:
特性 | 编码器 (Encoder) | 解码器 (Decoder) |
---|---|---|
注意力类型 | 自注意力(Self-Attention) 全局无掩码 | 1. 掩码自注意力(Masked Self-Attention) 2. 编码器-解码器注意力(Cross-Attention) |
输入依赖 | 仅处理输入序列(如源语言句子) | 依赖两部分输入: - 编码器的输出(源序列编码) - 已生成的目标序列(自回归) |
是否需要掩码 | 否(可看到完整输入序列) | 是(掩码未来位置,防止作弊) |
典型应用场景 | - BERT(仅编码器) - 文本分类 - 特征提取 | - GPT(仅解码器) - 机器翻译 - 文本生成 |
关键区别解析
1. 注意力机制不同
- 编码器:仅使用自注意力,计算输入序列内部所有元素的关系(例如分析句子中单词之间的依赖)。
# 编码器自注意力(无掩码) self_attn = MultiHeadAttention(d_model, num_heads) output = self_attn(query=x, key=x, value=x) # Q=K=V
- 解码器:
- 掩码自注意力:防止解码时看到未来信息(如生成第3个词时,只能看前2个词)。
# 解码器掩码自注意力 mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool() output = self_attn(query=x, key=x, value=x, attn_mask=mask)
- 编码器-解码器注意力:建立源序列和目标序列的关联(如翻译中对齐的单词)。
# 编码器-解码器注意力 cross_attn = MultiHeadAttention(d_model, num_heads) output = cross_attn(query=decoder_x, key=encoder_x, value=encoder_x) # Q来自解码器,K/V来自编码器
- 掩码自注意力:防止解码时看到未来信息(如生成第3个词时,只能看前2个词)。
2. 输入依赖不同
- 编码器:仅需输入序列(例如待翻译的英文句子)。
- 解码器:
- 训练时:接收完整目标序列(但通过掩码隐藏未来位置)。
- 推理时:逐步生成(每次输入已生成的部分序列,如
<START> → <START>我 → <START>我爱
)。
3. 掩码的作用
- 编码器无需掩码:可访问整个输入序列的所有位置。
- 解码器必须掩码:确保生成时只能看到“过去”信息,避免数据泄露。
4. 典型模型
- 仅编码器模型(如BERT):适合理解任务(分类、实体识别)。
- 仅解码器模型(如GPT):适合生成任务(文本续写、对话)。
- 编码器-解码器模型(如T5、BART):适合序列到序列任务(翻译、摘要)。
为什么这些区别重要?
-
并行化能力:
- 编码器可并行处理整个输入序列。
- 解码器在训练时通过掩码模拟逐步生成,实现并行;但推理时必须串行。
-
信息流控制:
- 编码器专注于理解输入。
- 解码器需兼顾输入和已生成内容,动态调整输出。
-
应用场景分离:
- 编码器更适合特征提取(如BERT提取句子向量)。
- 解码器更适合创造性任务(如GPT-3生成故事)。
5. 关键问题解答
Q1:为什么需要残差连接和层归一化?
- 残差连接缓解梯度消失,层归一化稳定训练。
Q2:自注意力和交叉注意力的区别?
- 自注意力:序列内部关系(Query=Key=Value);
- 交叉注意力:连接编码器和解码器(Query来自解码器,Key/Value来自编码器)。
Q3:Transformer 如何实现并行化?
- 编码器可并行处理整个输入序列;
- 解码器在训练时可通过掩码实现并行,推理时需逐步生成。
6. 总结
- 编码器:提取输入序列的全局特征,核心是自注意力和FFN。
- 解码器:自回归生成目标序列,依赖掩码注意力和交叉注意力。
Transformer 的灵活设计使其成为 NLP 和 CV 的基石架构(如 BERT、GPT、ViT)。理解其编码器-解码器机制是掌握现代深度学习模型的关键!
代码完整实现:可参考 HuggingFace Transformers 或 PyTorch 官方教程。