当前位置: 首页 > backend >正文

《零基础入门AI:长短期记忆网络(LSTM)与门控循环单元(GRU)(原理、结构与实现)》

1. 引言

标准循环神经网络(RNN)通过隐藏状态 hth_tht 实现对序列历史信息的编码,其递归更新机制为:
ht=f(Whht−1+Wxxt+bh) h_t = f(W_h h_{t-1} + W_x x_t + b_h) ht=f(Whht1+Wxxt+bh)
然而,在反向传播过程中,梯度需沿时间维度传递。当序列过长时,梯度在链式法则下可能指数级衰减(消失)或增长(爆炸),导致模型难以学习远距离依赖。这一现象被称为梯度消失/爆炸问题(Vanishing/Exploding Gradient Problem)

为克服该问题,Hochreiter & Schmidhuber (1997) 提出长短期记忆网络(LSTM),随后 Cho et al. (2014) 提出简化版本——门控循环单元(GRU)。二者均引入门控机制(Gating Mechanism),通过可学习的门控变量动态调节信息流动,从而有效缓解长期依赖建模的困难。


2. 长短期记忆网络(LSTM)

2.1 核心思想

LSTM 的核心在于引入细胞状态(Cell State) ct∈Rdcc_t \in \mathbb{R}^{d_c}ctRdc 作为长期记忆载体,并通过三个门控单元——遗忘门(Forget Gate)输入门(Input Gate)输出门(Output Gate)——控制信息的写入、保留与读取。

2.2 数学形式化

在时间步 ttt,LSTM 执行以下操作:

(1) 遗忘门(Forget Gate)

决定从细胞状态 ct−1c_{t-1}ct1 中丢弃哪些信息:
ft=σ(Wf⋅[ht−1,xt]+bf) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf[ht1,xt]+bf)

(2) 输入门(Input Gate)

决定哪些新信息将被写入细胞状态:
it=σ(Wi⋅[ht−1,xt]+bi) i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi[ht1,xt]+bi)

c~t=tanh⁡(Wc⋅[ht−1,xt]+bc) \tilde{c}_t = \tanh(W_c \cdot [h_{t-1}, x_t] + b_c) c~t=tanh(Wc[ht1,xt]+bc)

(3) 细胞状态更新

结合遗忘门与输入门,更新细胞状态:
ct=ft⊙ct−1+it⊙c~t c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t ct=ftct1+itc~t
其中 ⊙\odot 表示逐元素乘法(Hadamard product)。

(4) 输出门(Output Gate)

决定从细胞状态中输出哪些信息作为隐藏状态 hth_tht
ot=σ(Wo⋅[ht−1,xt]+bo) o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo[ht1,xt]+bo)

ht=ot⊙tanh⁡(ct) h_t = o_t \odot \tanh(c_t) ht=ottanh(ct)

符号说明

  • W∗∈Rdc×(dh+dx)W_* \in \mathbb{R}^{d_c \times (d_h + d_x)}WRdc×(dh+dx):各门的权重矩阵;
  • b∗∈Rdcb_* \in \mathbb{R}^{d_c}bRdc:偏置向量;
  • σ(⋅)\sigma(\cdot)σ():Sigmoid 激活函数,输出范围 [0,1][0,1][0,1],用于门控;
  • [⋅,⋅][\cdot,\cdot][,]:向量拼接操作;
  • dcd_cdc:细胞状态维度(通常 dc=dhd_c = d_hdc=dh)。

2.3 结构优势

  • 长期记忆保持:细胞状态 ctc_tct 在无干扰情况下可近乎恒定地传递信息(当 ft≈1,it≈0f_t \approx 1, i_t \approx 0ft1,it0),梯度可沿此路径稳定传播。
  • 选择性更新:门控机制允许模型学习何时“遗忘”旧信息、何时“记忆”新信息。

3. 门控循环单元(GRU)

3.1 设计动机

GRU 是 LSTM 的简化变体,旨在以更少的参数和计算复杂度实现类似性能。其将 LSTM 的细胞状态与隐藏状态合并,并将三个门简化为两个门。

3.2 数学形式化

(1) 更新门(Update Gate)

决定前一状态 ht−1h_{t-1}ht1 与候选状态 h~t\tilde{h}_th~t 的混合比例:
zt=σ(Wz⋅[ht−1,xt]+bz) z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z) zt=σ(Wz[ht1,xt]+bz)

(2) 重置门(Reset Gate)

决定对前一状态 ht−1h_{t-1}ht1 的“重置”程度,影响候选状态的计算:
rt=σ(Wr⋅[ht−1,xt]+br) r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r) rt=σ(Wr[ht1,xt]+br)

(3) 候选隐藏状态

使用重置门调制前一状态:
h~t=tanh⁡(Wh⋅[rt⊙ht−1,xt]+bh) \tilde{h}_t = \tanh(W_h \cdot [r_t \odot h_{t-1}, x_t] + b_h) h~t=tanh(Wh[rtht1,xt]+bh)

(4) 最终隐藏状态

由更新门控制新旧状态的加权融合:
ht=(1−zt)⊙ht−1+zt⊙h~t h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t ht=(1zt)ht1+zth~t

符号说明

  • Wz,Wr,Wh∈Rdh×(dh+dx)W_z, W_r, W_h \in \mathbb{R}^{d_h \times (d_h + d_x)}Wz,Wr,WhRdh×(dh+dx):权重矩阵;
  • bz,br,bh∈Rdhb_z, b_r, b_h \in \mathbb{R}^{d_h}bz,br,bhRdh:偏置向量。

3.3 与LSTM的对比

特性LSTMGRU
状态变量细胞状态 ctc_tct + 隐藏状态 hth_tht单一隐藏状态 hth_tht
门控数量3(遗忘、输入、输出)2(更新、重置)
参数量较多较少
计算复杂度较高较低
性能表现通常略优接近LSTM

4. 实现方法(PyTorch)

4.1 使用高级API构建模型

PyTorch 提供 nn.LSTMnn.GRU 模块,封装了复杂计算。

import torch
import torch.nn as nnclass SequenceModel(nn.Module):def __init__(self, vocab_size: int, embed_dim: int, hidden_dim: int, num_classes: int, model_type: str = 'LSTM', num_layers: int = 1):super(SequenceModel, self).__init__()assert model_type in ['LSTM', 'GRU'], "model_type must be 'LSTM' or 'GRU'"self.embedding = nn.Embedding(vocab_size, embed_dim)# 选择模型类型if model_type == 'LSTM':self.rnn = nn.LSTM(input_size=embed_dim,hidden_size=hidden_dim,num_layers=num_layers,batch_first=True,dropout=0.5 if num_layers > 1 else 0,bidirectional=False)self.fc = nn.Linear(hidden_dim, num_classes)else:  # GRUself.rnn = nn.GRU(input_size=embed_dim,hidden_size=hidden_dim,num_layers=num_layers,batch_first=True,dropout=0.5 if num_layers > 1 else 0,bidirectional=False)self.fc = nn.Linear(hidden_dim, num_classes)self.dropout = nn.Dropout(0.5)def forward(self, x: torch.Tensor) -> torch.Tensor:"""前向传播。Args:x: 输入序列,形状 (batch_size, seq_len)Returns:logits: 分类logits,形状 (batch_size, num_classes)"""embedded = self.dropout(self.embedding(x))  # (B, T, E)rnn_out, _ = self.rnn(embedded)             # rnn_out: (B, T, H)# 取最后一个时间步输出(适用于分类任务)last_output = rnn_out[:, -1, :]             # (B, H)logits = self.fc(self.dropout(last_output)) # (B, C)return logits

4.2 训练注意事项

  • 梯度裁剪:仍建议使用 torch.nn.utils.clip_grad_norm_ 防止梯度爆炸。
  • 批量处理:对变长序列使用 pack_padded_sequencepad_packed_sequence 提高效率。
  • 初始化:隐藏状态通常由框架自动初始化为零,亦可使用可学习参数。

5. 应用场景与局限性

5.1 典型应用场景

  • 自然语言处理:机器翻译、文本生成、情感分析;
  • 语音识别:声学模型;
  • 时间序列预测:金融、气象、工业传感器数据;
  • 手写识别:在线手写轨迹建模。

5.2 局限性

  • 并行化困难:时间步的顺序依赖限制训练速度;
  • 超长序列建模:仍可能遗忘极早期信息;
  • 被Transformer超越:自注意力机制在多数任务中表现更优。

6. 结论

LSTM 与 GRU 作为门控循环网络的代表,通过引入可学习的门控机制,有效缓解了标准 RNN 的梯度消失问题,显著提升了对长期依赖的建模能力。LSTM 通过细胞状态与三门结构实现精细控制,而 GRU 以更简洁的设计实现相近性能。二者在深度学习发展史上具有里程碑意义,至今仍在特定任务中广泛应用。理解其内部机制不仅有助于模型调优,也为掌握更先进的序列模型(如Transformer)奠定坚实基础。


http://www.xdnf.cn/news/18793.html

相关文章:

  • 【大前端】实现一个前端埋点SDK,并封装成NPM包
  • 【机械故障】旋转机械故障引起的振动信号调制效应概述
  • 在线教育系统源码助力教培转型:知识付费平台开发的商业实践
  • 达索 Enovia 许可管理技术白皮书:机制解析与智能优化实践
  • 面试 总结(1)
  • 项目集升级:顶部导览优化、字段自定义、路线图双模式、阶段图掌控、甘特图升级、工作量优化、仪表盘权限清晰
  • 31.Encoder-Decoder(Seq2Seq)
  • Docker详细学习
  • 【Protues仿真】定时器
  • 构建智能提示词工程师:LangGraph 的自动化提示词生成流程
  • [在实践中学习] 中间件理论和方法--Redis
  • WPF基于LiveCharts2图形库,实现:折线图,柱状图,饼状图
  • Python爬虫实战:研究开源的高性能代理池,构建电商数据采集和分析系统
  • Pycharm
  • ​告别复杂计划!日事清推出脑图视图,支持节点拖拽与聚焦模式,让项目管理更直观​
  • MySQL 入门
  • 虚幻5引擎:我们是在创造世界,还是重新发现世界?
  • 基于SpringBoot的摄影跟拍约拍预约系统【2026最新】
  • [CS创世SD NAND征文] CS创世CSNP1GCR01-AOW在运动控制卡中的高可靠应用
  • 神经网络参数量计算详解
  • 如何用企业微信AI解决金融运维难题,让故障响应快、客服专业度高
  • EB_NXP_K3XX_GPIO配置使用
  • 深入理解内存屏障(Memory Barrier):现代多核编程的基石
  • Java大厂面试实战:从Spring Boot到微服务架构的全链路技术拆解
  • 破解VMware迁移难题的技术
  • 给高斯DB写一个函数实现oracle中GROUPING_ID函数的功能
  • 性能瓶颈定位更快更准:ARMS 持续剖析能力升级解析
  • Docker Compose 使用指南 - 1Panel 版
  • NR --PO计算
  • nginx代理 flink Dashboard、sentinel dashboard的问题