当前位置: 首页 > web >正文

实战篇----利用 LangChain 和 BERT 用于命名实体识别-----完整代码

上一篇文章讲解了Langchain,实现一个简单的demo,结合利用 LangChain 和 BERT 用于命名实体识别。

一、命名实体识别模型训练(bert+CRF)

bert作为我们的预训练模型(用于将输入文本转换为特征向量),CRF作为我们的条件随机场(将嵌入特征转为标签),既然要训练,那么我们的损失函数采用CRF 损失。

注意区分 交叉熵损失和CRF损失

CRF本身也有学习参数,一起参与梯度更新,只是参数为一块转移矩阵实现标签之间的关系建模。

实现代码如下,

模型和 分词器都是使用的bert base chinese

实现了一个结合BERT和CRF模型的命名实体识别(NER)任务。首先,定义了BertCRF类,利用BERT进行特征提取,并通过CRF层进行序列标签预测。数据预处理部分使用BertTokenizerFast对输入文本进行分词,同时将标签对齐到子词级别,处理特殊token。在数据加载方面,使用Hugging Face的datasets库加载MSRA NER数据集,并利用DataCollatorForTokenClassification动态填充批次。

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast, BertForTokenClassification, DataCollatorForTokenClassification
from torchcrf import CRF
from torch.optim import AdamW
from datasets import load_dataset
from seqeval.metrics import classification_report, accuracy_score
from tqdm.auto import tqdm# 定义BERT + CRF模型
class BertCRF(nn.Module):def __init__(self, bert_model_name, num_labels):super(BertCRF, self).__init__()# 使用预训练的BERT模型进行特征提取self.bert = BertForTokenClassification.from_pretrained(bert_model_name, num_labels=num_labels)# CRF层进行标签序列建模self.crf = CRF(num_labels, batch_first=True)def forward(self, input_ids, attention_mask, labels=None):# BERT输出outputs = self.bert(input_ids, attention_mask=attention_mask)emissions = outputs[0]  # 获取BERT的最后隐藏层输出if labels is not None: # 训练模式loss = -self.crf(emissions, labels, mask=attention_mask.bool())return losselse:predictions = self.crf.decode(emissions, mask=attention_mask.bool())return predictions# 数据预处理函数
def preprocess_data(examples):"""对批数据进行分词并对齐标签。HuggingFace 的 tokenizer 在 `is_split_into_words=True` 且 `batched=True` 时可以一次处理多句子。这里根据 `word_ids(batch_index=...)` 把原始词级别标签扩展到子词级别;对特殊 token (CLS、SEP、PAD) 使用 -100,使其在计算 loss 时被忽略。`msra_ner` 数据集的 `ner_tags` 已经是整数 ID,因此无需 label2id 转换。"""# 分词tokenized = tokenizer(examples["tokens"],
http://www.xdnf.cn/news/14809.html

相关文章:

  • flask使用-链接mongoDB
  • Python爬虫-爬取汽车之家全部汽车品牌及车型数据
  • ListExtension 扩展方法增加 转DataTable()方法
  • Lua现学现卖
  • DOP数据开放平台(真实线上项目)
  • 电商返利APP架构设计:如何基于Spring Cloud构建高并发佣金结算系统
  • OpenLayers 下载地图切片
  • 解决cursor无法下载插件等网络问题
  • vue-29(创建 Nuxt.js 项目)
  • 从用户到权限:解密 AWS IAM Identity Center 的授权之道
  • 给定一个没有重复元素的数组,写出生成这个数组的MaxTree的函数
  • TDengine 如何使用 MQTT 采集数据?
  • lambda、function基础/响应式编程基础
  • [论文阅读] 软件工程 | 微前端在电商领域的实践:一项案例研究的深度解析
  • NLP中的同义词替换及我踩的坑
  • 创客匠人视角:创始人 IP 打造为何成为知识变现的核心竞争力
  • 【算法深练】单调栈:有序入栈,及时删除垃圾数据
  • 鸿蒙5:组件监听和部分状态管理V2
  • 为何需要防爆平板?它究竟有何能耐?
  • 【龙泽科技】新能源汽车故障诊断仿真教学软件【吉利几何G6】
  • 学习使用dotnet-dump工具分析.net内存转储文件(2)
  • vue-28(服务器端渲染(SSR)简介及其优势)
  • 舵机在不同类型机器人中的应用
  • Python 数据分析与可视化 Day 10 - 数据合并与连接
  • Linux的top指令CPU占用率详解(白话版)——Linux进阶常用知识点
  • 网络缓冲区
  • uni-app项目实战笔记26--uniapp实现富文本展示
  • 展开说说:Android之ContentProvider源码浅析
  • 机器学习算法-K近邻算法-KNN
  • Linux tcp_info:监控TCP连接的秘密武器