《零基础入门AI:长短期记忆网络(LSTM)与门控循环单元(GRU)(原理、结构与实现)》
1. 引言
标准循环神经网络(RNN)通过隐藏状态 hth_tht 实现对序列历史信息的编码,其递归更新机制为:
ht=f(Whht−1+Wxxt+bh)
h_t = f(W_h h_{t-1} + W_x x_t + b_h)
ht=f(Whht−1+Wxxt+bh)
然而,在反向传播过程中,梯度需沿时间维度传递。当序列过长时,梯度在链式法则下可能指数级衰减(消失)或增长(爆炸),导致模型难以学习远距离依赖。这一现象被称为梯度消失/爆炸问题(Vanishing/Exploding Gradient Problem)。
为克服该问题,Hochreiter & Schmidhuber (1997) 提出长短期记忆网络(LSTM),随后 Cho et al. (2014) 提出简化版本——门控循环单元(GRU)。二者均引入门控机制(Gating Mechanism),通过可学习的门控变量动态调节信息流动,从而有效缓解长期依赖建模的困难。
2. 长短期记忆网络(LSTM)
2.1 核心思想
LSTM 的核心在于引入细胞状态(Cell State) ct∈Rdcc_t \in \mathbb{R}^{d_c}ct∈Rdc 作为长期记忆载体,并通过三个门控单元——遗忘门(Forget Gate)、输入门(Input Gate) 和 输出门(Output Gate)——控制信息的写入、保留与读取。
2.2 数学形式化
在时间步 ttt,LSTM 执行以下操作:
(1) 遗忘门(Forget Gate)
决定从细胞状态 ct−1c_{t-1}ct−1 中丢弃哪些信息:
ft=σ(Wf⋅[ht−1,xt]+bf)
f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)
ft=σ(Wf⋅[ht−1,xt]+bf)
(2) 输入门(Input Gate)
决定哪些新信息将被写入细胞状态:
it=σ(Wi⋅[ht−1,xt]+bi)
i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)
it=σ(Wi⋅[ht−1,xt]+bi)
c~t=tanh(Wc⋅[ht−1,xt]+bc) \tilde{c}_t = \tanh(W_c \cdot [h_{t-1}, x_t] + b_c) c~t=tanh(Wc⋅[ht−1,xt]+bc)
(3) 细胞状态更新
结合遗忘门与输入门,更新细胞状态:
ct=ft⊙ct−1+it⊙c~t
c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t
ct=ft⊙ct−1+it⊙c~t
其中 ⊙\odot⊙ 表示逐元素乘法(Hadamard product)。
(4) 输出门(Output Gate)
决定从细胞状态中输出哪些信息作为隐藏状态 hth_tht:
ot=σ(Wo⋅[ht−1,xt]+bo)
o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)
ot=σ(Wo⋅[ht−1,xt]+bo)
ht=ot⊙tanh(ct) h_t = o_t \odot \tanh(c_t) ht=ot⊙tanh(ct)
符号说明:
- W∗∈Rdc×(dh+dx)W_* \in \mathbb{R}^{d_c \times (d_h + d_x)}W∗∈Rdc×(dh+dx):各门的权重矩阵;
- b∗∈Rdcb_* \in \mathbb{R}^{d_c}b∗∈Rdc:偏置向量;
- σ(⋅)\sigma(\cdot)σ(⋅):Sigmoid 激活函数,输出范围 [0,1][0,1][0,1],用于门控;
- [⋅,⋅][\cdot,\cdot][⋅,⋅]:向量拼接操作;
- dcd_cdc:细胞状态维度(通常 dc=dhd_c = d_hdc=dh)。
2.3 结构优势
- 长期记忆保持:细胞状态 ctc_tct 在无干扰情况下可近乎恒定地传递信息(当 ft≈1,it≈0f_t \approx 1, i_t \approx 0ft≈1,it≈0),梯度可沿此路径稳定传播。
- 选择性更新:门控机制允许模型学习何时“遗忘”旧信息、何时“记忆”新信息。
3. 门控循环单元(GRU)
3.1 设计动机
GRU 是 LSTM 的简化变体,旨在以更少的参数和计算复杂度实现类似性能。其将 LSTM 的细胞状态与隐藏状态合并,并将三个门简化为两个门。
3.2 数学形式化
(1) 更新门(Update Gate)
决定前一状态 ht−1h_{t-1}ht−1 与候选状态 h~t\tilde{h}_th~t 的混合比例:
zt=σ(Wz⋅[ht−1,xt]+bz)
z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z)
zt=σ(Wz⋅[ht−1,xt]+bz)
(2) 重置门(Reset Gate)
决定对前一状态 ht−1h_{t-1}ht−1 的“重置”程度,影响候选状态的计算:
rt=σ(Wr⋅[ht−1,xt]+br)
r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r)
rt=σ(Wr⋅[ht−1,xt]+br)
(3) 候选隐藏状态
使用重置门调制前一状态:
h~t=tanh(Wh⋅[rt⊙ht−1,xt]+bh)
\tilde{h}_t = \tanh(W_h \cdot [r_t \odot h_{t-1}, x_t] + b_h)
h~t=tanh(Wh⋅[rt⊙ht−1,xt]+bh)
(4) 最终隐藏状态
由更新门控制新旧状态的加权融合:
ht=(1−zt)⊙ht−1+zt⊙h~t
h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t
ht=(1−zt)⊙ht−1+zt⊙h~t
符号说明:
- Wz,Wr,Wh∈Rdh×(dh+dx)W_z, W_r, W_h \in \mathbb{R}^{d_h \times (d_h + d_x)}Wz,Wr,Wh∈Rdh×(dh+dx):权重矩阵;
- bz,br,bh∈Rdhb_z, b_r, b_h \in \mathbb{R}^{d_h}bz,br,bh∈Rdh:偏置向量。
3.3 与LSTM的对比
特性 | LSTM | GRU |
---|---|---|
状态变量 | 细胞状态 ctc_tct + 隐藏状态 hth_tht | 单一隐藏状态 hth_tht |
门控数量 | 3(遗忘、输入、输出) | 2(更新、重置) |
参数量 | 较多 | 较少 |
计算复杂度 | 较高 | 较低 |
性能表现 | 通常略优 | 接近LSTM |
4. 实现方法(PyTorch)
4.1 使用高级API构建模型
PyTorch 提供 nn.LSTM
和 nn.GRU
模块,封装了复杂计算。
import torch
import torch.nn as nnclass SequenceModel(nn.Module):def __init__(self, vocab_size: int, embed_dim: int, hidden_dim: int, num_classes: int, model_type: str = 'LSTM', num_layers: int = 1):super(SequenceModel, self).__init__()assert model_type in ['LSTM', 'GRU'], "model_type must be 'LSTM' or 'GRU'"self.embedding = nn.Embedding(vocab_size, embed_dim)# 选择模型类型if model_type == 'LSTM':self.rnn = nn.LSTM(input_size=embed_dim,hidden_size=hidden_dim,num_layers=num_layers,batch_first=True,dropout=0.5 if num_layers > 1 else 0,bidirectional=False)self.fc = nn.Linear(hidden_dim, num_classes)else: # GRUself.rnn = nn.GRU(input_size=embed_dim,hidden_size=hidden_dim,num_layers=num_layers,batch_first=True,dropout=0.5 if num_layers > 1 else 0,bidirectional=False)self.fc = nn.Linear(hidden_dim, num_classes)self.dropout = nn.Dropout(0.5)def forward(self, x: torch.Tensor) -> torch.Tensor:"""前向传播。Args:x: 输入序列,形状 (batch_size, seq_len)Returns:logits: 分类logits,形状 (batch_size, num_classes)"""embedded = self.dropout(self.embedding(x)) # (B, T, E)rnn_out, _ = self.rnn(embedded) # rnn_out: (B, T, H)# 取最后一个时间步输出(适用于分类任务)last_output = rnn_out[:, -1, :] # (B, H)logits = self.fc(self.dropout(last_output)) # (B, C)return logits
4.2 训练注意事项
- 梯度裁剪:仍建议使用
torch.nn.utils.clip_grad_norm_
防止梯度爆炸。 - 批量处理:对变长序列使用
pack_padded_sequence
和pad_packed_sequence
提高效率。 - 初始化:隐藏状态通常由框架自动初始化为零,亦可使用可学习参数。
5. 应用场景与局限性
5.1 典型应用场景
- 自然语言处理:机器翻译、文本生成、情感分析;
- 语音识别:声学模型;
- 时间序列预测:金融、气象、工业传感器数据;
- 手写识别:在线手写轨迹建模。
5.2 局限性
- 并行化困难:时间步的顺序依赖限制训练速度;
- 超长序列建模:仍可能遗忘极早期信息;
- 被Transformer超越:自注意力机制在多数任务中表现更优。
6. 结论
LSTM 与 GRU 作为门控循环网络的代表,通过引入可学习的门控机制,有效缓解了标准 RNN 的梯度消失问题,显著提升了对长期依赖的建模能力。LSTM 通过细胞状态与三门结构实现精细控制,而 GRU 以更简洁的设计实现相近性能。二者在深度学习发展史上具有里程碑意义,至今仍在特定任务中广泛应用。理解其内部机制不仅有助于模型调优,也为掌握更先进的序列模型(如Transformer)奠定坚实基础。