当前位置: 首页 > ai >正文

《零基础入门AI:循环神经网络(Recurrent Neural Networks)(从原理到实现)》

1. 引言

在处理如文本、语音、时间序列等序列数据时,传统神经网络(如全连接网络或卷积网络)面临根本性挑战:它们假设输入样本之间相互独立,无法有效捕捉序列中元素间的时序依赖关系。例如,在自然语言中,“人工智能”一词的含义依赖于前后词汇的语境。为解决这一问题,循环神经网络(Recurrent Neural Network, RNN) 被提出,其通过引入隐藏状态(hidden state) 作为“记忆”机制,显式地编码序列的历史信息,从而实现对序列动态特性的建模。


2. 循环神经网络的形式化定义

2.1 基本结构

设输入序列为 x=(x1,x2,…,xT)\mathbf{x} = (x_1, x_2, \dots, x_T)x=(x1,x2,,xT),其中 xt∈Rdxx_t \in \mathbb{R}^{d_x}xtRdx 为第 ttt 个时间步的输入向量,TTT 为序列长度。RNN 在每个时间步 ttt 维护一个隐藏状态 ht∈Rdhh_t \in \mathbb{R}^{d_h}htRdh,其更新遵循以下递归公式:

ht=f(Whht−1+Wxxt+bh) h_t = f(W_h h_{t-1} + W_x x_t + b_h) ht=f(Whht1+Wxxt+bh)

其中:

  • Wh∈Rdh×dhW_h \in \mathbb{R}^{d_h \times d_h}WhRdh×dh 为隐藏状态到隐藏状态的权重矩阵;
  • Wx∈Rdh×dxW_x \in \mathbb{R}^{d_h \times d_x}WxRdh×dx 为输入到隐藏状态的权重矩阵;
  • bh∈Rdhb_h \in \mathbb{R}^{d_h}bhRdh 为偏置向量;
  • f(⋅)f(\cdot)f() 为非线性激活函数,通常为 tanh⁡\tanhtanhReLU\mathrm{ReLU}ReLU
  • 初始状态 h0h_0h0 通常初始化为零向量或可学习参数。

在时间步 ttt 的输出 yty_tyt 可通过输出层函数 g(⋅)g(\cdot)g() 生成:

yt=g(Wyht+by) y_t = g(W_y h_t + b_y) yt=g(Wyht+by)

其中 Wy∈Rdy×dhW_y \in \mathbb{R}^{d_y \times d_h}WyRdy×dh 为输出权重矩阵,by∈Rdyb_y \in \mathbb{R}^{d_y}byRdy 为输出偏置。

2.2 网络展开(Unrolling)

RNN 可视为在时间维度上展开的深度网络。在训练过程中,RNN 被“展开”为 TTT 个时间步的链式结构,每个时间步共享参数 Wh,Wx,bhW_h, W_x, b_hWh,Wx,bh。这种参数共享机制使得模型能够处理变长序列,并显著减少参数总量。


3. 为何使用RNN?—— 应用动机

3.1 处理序列依赖性

序列数据的本质特征在于其元素间存在时序依赖。例如,在自然语言中,一个词的语义往往依赖于其上下文。RNN 通过隐藏状态 hth_tht 显式地编码从 x1x_1x1xtx_txt 的历史信息,使模型能够捕捉此类依赖关系。

3.2 参数共享与变长输入

与卷积神经网络(CNN)在空间维度上的参数共享类似,RNN 在时间维度上共享参数。这一特性使模型具备处理任意长度输入序列的能力,而无需为不同长度序列设计独立模型。

3.3 通用序列建模能力

RNN 可应用于多种序列到序列(sequence-to-sequence)映射任务,包括但不限于:

  • 序列标注:如命名实体识别(NER)、词性标注;
  • 序列分类:如情感分析、文档分类;
  • 序列生成:如机器翻译、文本摘要、语音合成;
  • 时间序列预测:如股票价格预测、气象预报。

4. RNN的实现

4.1 (引入)举个具体例子

我们用一个简单句子演示:

“我 爱 学习”

第一步:每个词变成向量(词向量)

向量(简化版)
[1.0, 0.2]
[0.8, -0.3]
学习[-0.5, 0.9]

这些向量是“词的意思”的数字表示。


第二步:RNN 一个词一个词地读

时间步 1:读“我”

  • 输入:x = [1.0, 0.2]
  • 之前没有记忆,所以 h₀ = [0, 0](初始状态)
  • RNN 计算:根据“我” + “之前的记忆”,更新记忆
  • 得到新的记忆:h₁ = [0.6, 0.1]

RNN 心里想:“现在提到‘我’,可能是第一人称,开始说话了。”


时间步 2:读“爱”

  • 输入:x = [0.8, -0.3]
  • 之前的记忆:h₁ = [0.6, 0.1]
  • RNN 计算:把“爱”和“前面说‘我’”结合起来
  • 得到新记忆:h₂ = [0.7, -0.2]

RNN 心里想:“我 + 爱 → 这个人喜欢什么?情绪是正面的。”


时间步 3:读“学习”

  • 输入:x = [-0.5, 0.9]
  • 之前的记忆:h₂ = [0.7, -0.2]
  • RNN 计算:把“学习”和“我喜欢”结合起来
  • 得到新记忆:h₃ = [0.3, 0.5]

RNN 心里想:“这个人喜欢学习 → 整体是积极向上的句子。”


所以,“记忆”到底是啥?

问题回答
是什么?一个数字列表(向量),比如 [0.3, 0.5]
存在哪?存在一个变量 h 里,每一步都会更新它
怎么实现?每次把“新词”和“旧记忆”输入一个神经网络,输出“新记忆”
有什么用?后面如果出现“它”“他”“这个”,RNN 可以用这个记忆去理解

“记忆”不是记住每个词,而是不断总结当前的理解

4.2 基础RNN单元的实现

以下以PyTorch为例,展示基础RNN单元的实现,为理解高级API提供理论支撑。

import torch
import torch.nn as nnclass RNNCell(nn.Module):def __init__(self, input_size: int, hidden_size: int, activation: str = 'tanh'):super(RNNCell, self).__init__()self.input_size = input_sizeself.hidden_size = hidden_sizeself.activation = torch.tanh if activation == 'tanh' else torch.relu# 定义权重矩阵self.weight_ih = nn.Linear(input_size, hidden_size, bias=True)   # W_x x_tself.weight_hh = nn.Linear(hidden_size, hidden_size, bias=True)  # W_h h_{t-1}def forward(self, x_t: torch.Tensor, h_prev: torch.Tensor) -> torch.Tensor:"""前向传播:计算单个时间步的隐藏状态。Args:x_t: 当前输入,形状 (batch_size, input_size)h_prev: 上一时刻隐藏状态,形状 (batch_size, hidden_size)Returns:h_t: 当前隐藏状态,形状 (batch_size, hidden_size)"""h_t = self.activation(self.weight_ih(x_t) + self.weight_hh(h_prev))return h_t

4.3 使用高级API实现完整RNN模型

现代深度学习框架(如PyTorch、TensorFlow)提供高级RNN模块,简化实现过程。以下展示一个完整的序列分类模型。

class RNNClassifier(nn.Module):def __init__(self, vocab_size: int, embed_dim: int, hidden_dim: int, num_classes: int, num_layers: int = 1, bidirectional: bool = False):super(RNNClassifier, self).__init__()self.embedding = nn.Embedding(vocab_size, embed_dim)self.rnn = nn.RNN(input_size=embed_dim,hidden_size=hidden_dim,num_layers=num_layers,batch_first=True,bidirectional=bidirectional,dropout=0.5 if num_layers > 1 else 0)self.fc = nn.Linear(hidden_dim * (2 if bidirectional else 1), num_classes)self.dropout = nn.Dropout(0.5)def forward(self, x: torch.Tensor) -> torch.Tensor:"""Args:x: 输入序列,形状 (batch_size, seq_len)Returns:logits: 分类logits,形状 (batch_size, num_classes)"""embedded = self.dropout(self.embedding(x))  # (B, T, E)rnn_out, hidden = self.rnn(embedded)        # rnn_out: (B, T, H)# 使用最后一个时间步的输出进行分类last_output = rnn_out[:, -1, :]             # (B, H)logits = self.fc(self.dropout(last_output)) # (B, C)return logits

4.4 训练与优化

  • 梯度裁剪(Gradient Clipping):RNN训练中易出现梯度爆炸,建议使用 torch.nn.utils.clip_grad_norm_ 限制梯度范数。
  • 批量处理:使用 DataLoaderPaddedSequence 处理变长序列。
  • 设备管理:确保模型和数据在相同设备(CPU/GPU)上。
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()for batch in dataloader:optimizer.zero_grad()logits = model(batch.text)loss = criterion(logits, batch.label)loss.backward()# 梯度裁剪torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)optimizer.step()

5. 局限性与演进

尽管RNN在序列建模中取得显著成果,其仍存在以下局限:

  • 梯度消失/爆炸问题:在长序列中,梯度在反向传播过程中可能指数级衰减或增长,导致难以学习长期依赖;
  • 训练效率低:由于时间步的顺序依赖,难以并行化训练。

为克服上述问题,研究者提出长短期记忆网络(LSTM)门控循环单元(GRU),通过引入门控机制有效缓解梯度问题。近年来,基于自注意力机制的Transformer模型在多数任务中超越RNN,成为序列建模的新范式。


6. 和 CBOW 的本质区别

对比点CBOWRNN
输入方式同时输入多个词(“我”“学习”)一个一个词输入(“我”→“爱”→“学”)
是否有序不关心顺序(“我学习爱”也能训练)非常关心顺序(“我爱学习” ≠ “学习爱我”)
输出一个词(中心词)一个概率分布(下一个词)
用途得到词向量做语言模型、生成文本、分类等

7. 结论

循环神经网络作为处理序列数据的基础架构,通过其递归结构实现了对时序依赖的有效建模。其参数共享机制与变长输入处理能力使其在自然语言处理、语音识别等领域具有重要应用价值。尽管存在梯度问题等局限,RNN的理论框架为后续LSTM、GRU及现代序列模型的发展奠定了坚实基础。掌握RNN的原理与实现方法,对于理解深度学习在序列任务中的演进路径具有重要意义。


参考文献

  1. Rumelhart, D. E., Hinton, G. E., & Williams, R. J. (1986). Learning representations by back-propagating errors. Nature, 323(6088), 533–536.
  2. Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory. Neural computation, 9(8), 1735–1780.
  3. Cho, K., et al. (2014). Learning phrase representations using RNN encoder-decoder for statistical machine translation. EMNLP.
  4. Goodfellow, I., Bengio, Y., & Courville, A. (2016). Deep Learning. MIT Press.
http://www.xdnf.cn/news/18676.html

相关文章:

  • Java中的反射机制
  • MyBatis 从入门到精通:一篇就够的实战指南(Java)
  • 3-3〔OSCP ◈ 研记〕❘ WEB应用攻击▸WEB应用安全评估工具
  • 火山引擎配置CDN
  • 【Linux | 网络】多路转接IO之poll
  • 计算机网络课堂笔记
  • AutoCAD Electrical缺少驱动程序“AceRedist“解决方法
  • C++ Core Guidelines 核心理念
  • 关于单片机串口通讯的多机操作说明---单片机串口通讯如何实现多机操作?
  • 16-day13强化学习和训练大模型
  • 怎么把iphone文件传输到windows电脑?分场景选方法
  • jasperreports 使用
  • 解锁处暑健康生活
  • 企业级监控可视化系统 Prometheus + Grafana
  • LoRA(低秩适应,Low-Rank Adaptation)的 alpha 参数解析(54)
  • 雷卯针对香橙派Orange 4G-IOT开发板防雷防静电方案
  • kafka 原理详解
  • 【OpenAI】ChatGPT-4o-latest 真正的多模态、长文本模型的详细介绍+API的使用教程!
  • 深入理解 Python Scapy 库:网络安全与协议分析的瑞士军刀
  • ES6/ES2015 - ES16/ES2025
  • 在压力测试中如何确定合适的并发用户数?
  • 挖币与区块链技术有怎样的联系?
  • 基于 Prometheus+Alertmanager+Grafana 打造监控报警后台(一)-Prometheus介绍及安装
  • DMP-Net:面向脑组织术中成像的深度语义先验压缩光谱重建方法|文献速递-深度学习人工智能医疗图像
  • PyTorch实战(1)——深度学习概述
  • 阿里:基于设计逻辑的LLM数据合成
  • crc16是什么算法
  • C++ 指针与引用面试深度解析
  • STM32项目分享:基于STM32的智能洗衣机
  • 开源大模型天花板?DeepSeek-V3 6710亿参数MoE架构深度拆解