NLP:LSTM和GRU分享
本文目录:
- 一、 前置知识:RNN的痛点
- 二、LSTM(长短期记忆网络)
- (一)遗忘门(Forget Gate)
- (二)输入门(Input Gate)
- (三)输出门(Output Gate)
- (四)细胞状态(Cell State)
- (五)使用Pytorch构建LSTM模型
- 三、 GRU(门控循环单元)
- (一)更新门(Update Gate)
- (二)重置门(Reset Gate)
- (三)使用Pytorch构建GRU模型
- 文末附赠:
- (一)传统RNN、 LSTM 和 GRU 三者核心结构对比
- (二)传统RNN、 LSTM 和 GRU 三者性能表现对比
- (三)传统RNN、 LSTM 和 GRU 三者应用场景对比
前言:前面文章分享了传统RNN,此次分享传统RNN变体:LSTM和GRU。
一、 前置知识:RNN的痛点
传统RNN像金鱼记忆:
只能记住最近几步的信息(梯度消失/爆炸)
遇到长序列就懵圈(“开头说了啥来着?”)
二、LSTM(长短期记忆网络)
核心设计:记忆管控大师
想象LSTM是个图书馆管理员,它有三把钥匙:
(一)遗忘门(Forget Gate)
决定哪些旧记忆该丢弃:
“上个月的天气预报数据?可以忘了。”
(二)输入门(Input Gate)
决定哪些新信息值得记录:
“今天突然下冰雹?这个得重点记!”
(三)输出门(Output Gate)
决定当前输出什么信息:
“根据天气记录,建议你带伞。”
(四)细胞状态(Cell State)
像传送带,专门运输长期记忆。
公式精华:
f_t = σ(W_f · [h_{t-1}, x_t] + b_f) # 遗忘门
i_t = σ(W_i · [h_{t-1}, x_t] + b_i) # 输入门
C̃_t = tanh(W_C · [h_{t-1}, x_t] + b_C) # 候选记忆
C_t = f_t * C_{t-1} + i_t * C̃_t # 更新细胞状态
o_t = σ(W_o · [h_{t-1}, x_t] + b_o) # 输出门
h_t = o_t * tanh(C_t) # 最终输出
(五)使用Pytorch构建LSTM模型
位置: 在torch.nn工具包之中, 通过torch.nn.LSTM可调用。
nn.LSTM类初始化主要参数解释: input_size: 输入张量x中特征维度的大小。
hidden_size: 隐层张量h中特征维度的大小., num_layers: 隐含层的数量。
bidirectional: 是否选择使用双向LSTM, 如果为True, 则使用; 默认不使用。
nn.LSTM类实例化对象主要参数解释: input: 输入张量x、h0: 初始化的隐层张量h、c0: 初始化的细胞状态张量c。
nn.LSTM使用示例:
# 定义LSTM的参数含义: (input_size, hidden_size, num_layers)
# 定义输入张量的参数含义: (sequence_length, batch_size, input_size)
# 定义隐藏层初始张量和细胞初始状态张量的参数含义:
# (num_layers * num_directions, batch_size, hidden_size)>>> import torch.nn as nn
>>> import torch
>>> rnn = nn.LSTM(5, 6, 2)
>>> input = torch.randn(1, 3, 5)
>>> h0 = torch.randn(2, 3, 6)
>>> c0 = torch.randn(2, 3, 6)
>>> output, (hn, cn) = rnn(input, (h0, c0))
>>> output
tensor([[[ 0.0447, -0.0335, 0.1454, 0.0438, 0.0865, 0.0416],[ 0.0105, 0.1923, 0.5507, -0.1742, 0.1569, -0.0548],[-0.1186, 0.1835, -0.0022, -0.1388, -0.0877, -0.4007]]],grad_fn=<StackBackward>)
>>> hn
tensor([[[ 0.4647, -0.2364, 0.0645, -0.3996, -0.0500, -0.0152],[ 0.3852, 0.0704, 0.2103, -0.2524, 0.0243, 0.0477],[ 0.2571, 0.0608, 0.2322, 0.1815, -0.0513, -0.0291]],[[ 0.0447, -0.0335, 0.1454, 0.0438, 0.0865, 0.0416],[ 0.0105, 0.1923, 0.5507, -0.1742, 0.1569, -0.0548],[-0.1186, 0.1835, -0.0022, -0.1388, -0.0877, -0.4007]]],grad_fn=<StackBackward>)
>>> cn
tensor([[[ 0.8083, -0.5500, 0.1009, -0.5806, -0.0668, -0.1161],[ 0.7438, 0.0957, 0.5509, -0.7725, 0.0824, 0.0626],[ 0.3131, 0.0920, 0.8359, 0.9187, -0.4826, -0.0717]],[[ 0.1240, -0.0526, 0.3035, 0.1099, 0.5915, 0.0828],[ 0.0203, 0.8367, 0.9832, -0.4454, 0.3917, -0.1983],[-0.2976, 0.7764, -0.0074, -0.1965, -0.1343, -0.6683]]],grad_fn=<StackBackward>)
特别分享:什么是Bi-LSTM ?- Bi-LSTM即双向LSTM, 它没有改变LSTM本身任何的内部结构, 只是将LSTM应用两次且方向不同, 再将两次得到的LSTM结果进行拼接作为最终输出。
三、 GRU(门控循环单元)
核心设计:LSTM的极简版
GRU像效率至上的程序员,合并了LSTM的门:
(一)更新门(Update Gate)
二合一:同时控制遗忘和输入
“旧记忆留多少?新记忆收多少?这门说了算!”
(二)重置门(Reset Gate)
决定多少过去信息用于计算新状态:
“昨天的股票数据对预测今天有用吗?”
独特优势:
参数比LSTM少(训练更快)
公式精华:
z_t = σ(W_z · [h_{t-1}, x_t]) # 更新门
r_t = σ(W_r · [h_{t-1}, x_t]) # 重置门
h̃_t = tanh(W · [r_t * h_{t-1}, x_t]) # 候选状态
h_t = (1-z_t) * h_{t-1} + z_t * h̃_t # 最终状态
(三)使用Pytorch构建GRU模型
位置: 在torch.nn工具包之中, 通过torch.nn.GRU可调用。
nn.GRU类初始化主要参数解释: input_size: 输入张量x中特征维度的大小。
hidden_size: 隐层张量h中特征维度的大小, num_layers: 隐含层的数量。
bidirectional: 是否选择使用双向LSTM, 如果为True, 则使用; 默认不使用。
nn.GRU类实例化对象主要参数解释: * input: 输入张量x. * h0: 初始化的隐层张量h。
nn.GRU使用示例:
>>> import torch
>>> import torch.nn as nn
>>> rnn = nn.GRU(5, 6, 2)
>>> input = torch.randn(1, 3, 5)
>>> h0 = torch.randn(2, 3, 6)
>>> output, hn = rnn(input, h0)
>>> output
tensor([[[-0.2097, -2.2225, 0.6204, -0.1745, -0.1749, -0.0460],[-0.3820, 0.0465, -0.4798, 0.6837, -0.7894, 0.5173],[-0.0184, -0.2758, 1.2482, 0.5514, -0.9165, -0.6667]]],grad_fn=<StackBackward>)
>>> hn
tensor([[[ 0.6578, -0.4226, -0.2129, -0.3785, 0.5070, 0.4338],[-0.5072, 0.5948, 0.8083, 0.4618, 0.1629, -0.1591],[ 0.2430, -0.4981, 0.3846, -0.4252, 0.7191, 0.5420]],[[-0.2097, -2.2225, 0.6204, -0.1745, -0.1749, -0.0460],[-0.3820, 0.0465, -0.4798, 0.6837, -0.7894, 0.5173],[-0.0184, -0.2758, 1.2482, 0.5514, -0.9165, -0.6667]]],grad_fn=<StackBackward>)
特别分享:什么是Bi-GRU ?Bi-GRU与Bi-LSTM的逻辑相同, 都是不改变其内部结构, 而是将模型应用两次且方向不同, 再将两次得到的LSTM结果进行拼接作为最终输出。
文末附赠:
(一)传统RNN、 LSTM 和 GRU 三者核心结构对比
(二)传统RNN、 LSTM 和 GRU 三者性能表现对比
(三)传统RNN、 LSTM 和 GRU 三者应用场景对比
今天的分享到此结束(改用了一种更温馨的文体~希望大家喜欢(╹▽╹)