当前位置: 首页 > java >正文

双向长短期记忆网络-BiLSTM

5月14日复盘

二、BiLSTM

1. 概述

双向长短期记忆网络(Bi-directional Long Short-Term Memory,BiLSTM)是一种扩展自长短期记忆网络(LSTM)的结构,旨在解决传统 LSTM 模型只能考虑到过去信息的问题。BiLSTM 在每个时间步同时考虑了过去和未来的信息,从而更好地捕捉了序列数据中的双向上下文关系。

BiLSTM 的创新点在于引入了两个独立的 LSTM 层,一个按正向顺序处理输入序列,另一个按逆向顺序处理输入序列。这样,每个时间步的输出就包含了当前时间步之前和之后的信息,进而使得模型能够更好地理解序列数据中的语义和上下文关系。

  • 正向传递: 输入序列按照时间顺序被输入到第一个LSTM层。每个时间步的输出都会被计算并保留下来。

  • 反向传递: 输入序列按照时间的逆序(即先输入最后一个元素)被输入到第二个LSTM层。与正向传递类似,每个时间步的输出都会被计算并保留下来。

  • 合并输出: 在每个时间步,将两个LSTM层的输出通过某种方式合并(如拼接或加和)以得到最终的输出。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

2. BILSTM模型应用背景

命名体识别

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

标注集

BMES标注集

分词的标注集并非只有一种,举例中文分词的情况,汉子作为词语开始Begin,结束End,中间Middle,单字Single,这四种情况就可以囊括所有的分词情况。于是就有了BMES标注集,这样的标注集在命名实体识别任务中也非常常见。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

词性标注

在序列标注问题中单词序列就是x,词性序列就是y,当前词词性的判定需要综合考虑前后单词的词性。而标注集最著名的就是863标注集和北大标注集。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

3. 代码实现

原生代码

import numpy as npdef sigmoid(x):return 1 / (1 + np.exp(-x))def tanh(x):return np.tanh(x)class GRU:def __init__(self, input_size, hidden_size, output_size):self.input_size = input_sizeself.hidden_size = hidden_sizeself.output_size = output_size#权重矩阵和偏置self.W_z = np.random.randn(hidden_size + input_size, hidden_size)self.b_z = np.zeros((hidden_size,))self.W_r = np.random.randn(hidden_size + input_size, hidden_size)self.b_r = np.zeros((hidden_size,))# ht候选self.W = np.random.randn(hidden_size + input_size, hidden_size)self.b = np.zeros((hidden_size,))def forward(self, x, h_last):""":param x: [s,dim]:param h_last::return:"""# 初始化状态h_prev = np.zeros((self.hidden_size,))h_all = []for i in range(x.shape[0]):x_t = x[i]x_t_h_prev = np.concatenate((x_t, h_prev), axis=0)r_t = sigmoid(np.dot(x_t_h_prev, self.W_r) + self.b_r)z_t = sigmoid(np.dot(x_t_h_prev, self.W_z) + self.b_z)# h_prev = r_t * h_prevh_t_input = np.concatenate((x_t, h_prev * r_t), axis=0)h_t_candidate = tanh(np.dot(h_t_input, self.W) + self.b)h_t = (1 - z_t) * h_prev + z_t * h_t_candidateh_all.append(h_t)return h_allif __name__ == '__main__':gru = GRU(input_size=2, hidden_size=5, output_size=1)x = np.random.randn(3 , 2)h_last = np.zeros((3,))h_all = gru.forward(x, h_last)print(h_all)
# ---------------------------------------------------------------------------
import numpy as np# 创建一个包含两个二维数组的列表
inputs = [np.array([[0.1], [0.2], [0.3]]), np.array([[0.4], [0.5], [0.6]])]# 使用 numpy 库中的 np.stack 函数。这会将输入的二维数组堆叠在一起,从而形成一个新的三维数组
inputs_3d = np.stack(inputs)# 将三维数组转换为列表
list_from_3d_array = inputs_3d.tolist()print(list_from_3d_array)

Pytorch

import torch
import torch.nn as nn# 模型参数设置
batch_size = 10
sen_len = 6
hidden_size = 8input_size = 3
output_size = hidden_size * 2  # 类别是隐藏层大小的两倍# 初始化隐藏层状态
h_prev = torch.zeros(1, batch_size, hidden_size)# RNN调用
model = nn.GRU(input_size, hidden_size, batch_first=True)
fc = nn.Linear(hidden_size, output_size)  # 全连接层用于分类# 初始化数据
x = torch.randn(10, 6, 3)out, h_next = model(x, h_prev)
# 对每个时间步的输出进行分类
out = out.contiguous().view(-1, hidden_size)  # 调整形状为 (batch_size * sen_len, hidden_size)
out = fc(out)
out = out.view(batch_size, sen_len, output_size)  # 调整回 (batch_size, sen_len, output_size)print("多对多输出:")
print(out.shape)
print(out)
print(h_next.shape)
print(h_next)out, h_next = model(x, h_prev)
# 只对最后一个时间步的输出进行分类
final_out = h_next.squeeze(0)  # 移除多余的维度,得到 (batch_size, hidden_size)
final_out = fc(final_out)print("\n多对一输出:")
print(final_out.shape)
print(final_out)
print(h_next.shape)
print(h_next)
http://www.xdnf.cn/news/6024.html

相关文章:

  • 鸿蒙OSUniApp打造多功能图表展示组件 #三方框架 #Uniapp
  • 行项目违反范围截止值
  • electron结合vue,直接访问静态文件如何跳转访问路径
  • 【IPMV】图像处理与机器视觉:Lec11 Keypoint Features and Corners
  • 以太网供电(PoE)交换机与自愈网络功能:打卡系统的得力助手
  • 基于 Spring Boot 瑞吉外卖系统开发(十四)
  • Vue 和 React 状态管理的性能优化策略对比
  • 数据结构中的高级排序算法
  • Linux内核可配置的参数
  • 单片机-STM32部分:14、SPI
  • 查询公网IP地址的方法:查看自己是不是公网ip,附内网穿透外网域名访问方案
  • 构建优雅对象的艺术:Java 建造者模式的架构解析与工程实践
  • HarmonyOs开发之———使用HTTP访问网络资源
  • Eslint和perrier的作用
  • CSS盒子模型:Padding与Margin的适用场景与注意事项
  • npm 报错 gyp verb `which` failed Error: not found: python2 解决方案
  • 【漫话机器学习系列】259.神经网络参数的初始化(Initialization Of Neural Network Parameters)
  • 【Java面试题】——this 和 super 的区别
  • PHP黑白胶卷底片图转彩图功能 V2025.05.15
  • Stable Diffusion WebUI 插件大全:功能详解与下载地址
  • 【软件测试】:推荐一些接口与自动化测试学习练习网站(API测试与自动化学习全攻略)
  • 配置Nginx解决http host头攻击漏洞【详细步骤】
  • Dockerfile实战:从零构建自定义CentOS镜像
  • Python爬虫实战:研究进制流数据,实现逆向解密
  • 【优选算法 | 字符串】字符串模拟题精选:思维+实现解析
  • 【python实用小脚本-59】连续刷题7天,手动整理编程题目效率低下,Python代码5分钟搞定,效率提升80%(附方案)
  • 力扣刷题Day 48:盛最多水的容器(283)
  • Linux操作系统中的SOCKET相关 - Socket字节序调整与网络传输
  • Kubernetes 标签和注解
  • 【软件测试】第一章·软件测试概述