当前位置: 首页 > java >正文

医疗问答检索任务的完整 Pipeline 示例

示例代码:

# 医疗问答检索任务完整 Pipeline 示例
# 包含训练数据、retrieval、评估三步from typing import Dict, List
from collections import defaultdict
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn.functional as F# 模拟的语料库(corpus)
corpus = {"d1": "糖尿病是一种慢性病,需要控制饮食和规律运动。","d2": "高血压与钠摄入量有关,应减少食盐摄入。","d3": "血糖控制可以通过口服降糖药或胰岛素治疗。","d4": "冠心病患者应避免剧烈运动。"
}# 模拟的任务数据(包含训练和评估)
task_data = [{"query": "糖尿病患者如何控制血糖?","positive": ["d1", "d3"],"negative": ["d2", "d4"]},{"query": "高血压如何治疗?","positive": ["d2"],"negative": ["d1", "d3", "d4"]}
]# 假设我们加载一个分类模型(判断是否相关)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-snli")
model.eval()# 检索函数:用模型判断每个 query 与文档是否相关
def query_retrieval(query: str, corpus: Dict[str, str]) -> Dict[str, float]:inputs = tokenizer([query] * len(corpus), list(corpus.values()), padding=True, truncation=True, return_tensors="pt")with torch.no_grad():outputs = model(**inputs)probs = F.softmax(outputs.logits, dim=-1)relevance_scores = probs[:, 1]  # 假设第1类是“相关”return {k: v.item() for k, v in zip(corpus.keys(), relevance_scores)}# 构造 retrieved_results(模拟模型推理)
retrieved_results = {}
for item in task_data:query = item["query"]retrieved_results[query] = query_retrieval(query, corpus)# 构造评估用的标准答案(Ground Truth)
relevant_docs = {item["query"]: set(item["positive"]) for item in task_data}# 简单评估:计算 Recall@2
def evaluate_recall(retrieved_results: Dict[str, Dict[str, float]], relevant_docs: Dict[str, set], k: int = 2):recalls = []for query, scores in retrieved_results.items():top_k_docs = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:k]top_k_doc_ids = [doc_id for doc_id, _ in top_k_docs]num_hits = len(set(top_k_doc_ids) & relevant_docs[query])recall = num_hits / len(relevant_docs[query])print(f"Query: {query}\nTop-{k}: {top_k_doc_ids}\nRecall@{k}: {recall:.2f}\n")recalls.append(recall)avg_recall = sum(recalls) / len(recalls)print(f"Average Recall@{k}: {avg_recall:.2f}")# 评估
evaluate_recall(retrieved_results, relevant_docs, k=2)

输出结果:
在这里插入图片描述

这段代码是一个医疗问答检索任务的完整 Pipeline 示例,涵盖了以下三个核心功能:


功能一:构造模拟检索任务数据(corpus 与任务集)

作用:

  • 模拟一个简单的检索场景,用于测试问答匹配系统的效果。

内容:

  • corpus:构造一个小型文档库(4条医疗相关文本),每条文本用 d1, d2, … 编号。
  • task_data:模拟用户查询,标注其相关文档(positive)和不相关文档(negative)。

🔎 示例:

task_data = [{"query": "糖尿病患者如何控制血糖?","positive": ["d1", "d3"],"negative": ["d2", "d4"]},

这表示:对这个 query,d1d3 是相关文档。


功能二:基于句对分类模型进行检索匹配

作用:

  • 使用一个现成的分类模型(比如 SNLI 领域的 BERT)来判断 query 和 每篇文档 的相关性打分。

核心函数:

def query_retrieval(query: str, corpus: Dict[str, str]) -> Dict[str, float]:

步骤详细解释:

  1. Tokenizer 编码: query 与每个 corpus 文档形成句对,用 BERT 的 tokenizer 编码。
  2. 模型推理: 输入模型做前向传播,得到每个句对的分类 logits。
  3. Softmax 得分: 用 softmax 转换为概率,假设“相关”是第1类(probs[:, 1])。
  4. 返回结果: 返回一个字典,每个文档对应一个相关性得分。

🔎 示例:

{'d1': 0.85, 'd2': 0.12, 'd3': 0.78, 'd4': 0.09}

表示该 query 与 d1, d3 最相关。


功能三:评估模型效果(Recall@k)

作用:

  • 检查模型检索的 top-k 结果中,有多少比例是标注为相关的文档。

评估函数:

def evaluate_recall(retrieved_results, relevant_docs, k=2)

步骤解释:

  1. 提取 top-k: 对每个 query 按得分排序,取前 k 个文档。
  2. 计算 Recall: Recall@k = (top-k 中相关文档数量) / (所有标注相关文档数量)
  3. 打印和平均: 每个 query 的 Recall@k 打印出来,同时统计平均值。

🔄 总结:整个流程实现了什么?

步骤描述
1️⃣ 数据准备构造 query、文档、标注对
2️⃣ 检索利用分类模型计算 query 与文档的相关性得分
3️⃣ 排序与评估取 top-k 文档,计算 recall,衡量模型效果

下面对代码部分进行详细解释

数据准备、检索到评估,我将按模块逐行详细解释


一、导入库和模型

from typing import Dict, List
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn.functional as F

含义说明:

  • Dict, List: 类型注解,说明变量的数据结构。
  • torch: PyTorch 深度学习库。
  • transformers: HuggingFace 提供的 Transformers 模型接口。
  • F.softmax: 用于将模型输出转为“概率”。

二、模拟语料库(corpus)

corpus = {"d1": "糖尿病是一种慢性病,需要控制饮食和规律运动。","d2": "高血压与钠摄入量有关,应减少食盐摄入。","d3": "血糖控制可以通过口服降糖药或胰岛素治疗。","d4": "冠心病患者应避免剧烈运动。"
}

含义:

  • 这是一个简单的文档库,每个文档有个 id(如 d1),和一段医学相关的文本内容。

三、任务数据(模拟的训练/评估样本)

task_data = [{"query": "糖尿病患者如何控制血糖?","positive": ["d1", "d3"],"negative": ["d2", "d4"]},{"query": "高血压如何治疗?","positive": ["d2"],"negative": ["d1", "d3", "d4"]}
]

含义:

  • 模拟用户提出的问题(query)。
  • positive: 该问题对应的正确文档 id。
  • negative: 与问题无关的文档 id。

例如:

  • 问题:“糖尿病患者如何控制血糖?”

    • 正确答案文档:d1(控制饮食)和 d3(药物治疗)。
    • 错误文档:d2(高血压)和 d4(冠心病)。

四、加载预训练模型

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-snli")
model.eval()

含义:

  • 加载 BERT 模型(用于句子对判断:是否相关)。
  • tokenizer: 把文本转成模型输入格式。
  • model.eval(): 设置模型为评估模式,不进行梯度更新。

注意:这里用的模型是 snli 任务(自然语言推理),它学会了判断两个句子是否有关系,非常适合做句子对的匹配任务。


五、检索函数:判断 Query 和每篇文档是否相关

def query_retrieval(query: str, corpus: Dict[str, str]) -> Dict[str, float]:inputs = tokenizer([query] * len(corpus), list(corpus.values()), padding=True, truncation=True, return_tensors="pt")with torch.no_grad():outputs = model(**inputs)probs = F.softmax(outputs.logits, dim=-1)relevance_scores = probs[:, 1]  # 假设第1类是“相关”return {k: v.item() for k, v in zip(corpus.keys(), relevance_scores)}

详细解释:

  • 输入:一个 query,比如“糖尿病怎么控制血糖?”
  • 针对每个文档,都和 query 配成一个句子对(比如“糖尿病控制”和 d1 文本)。
  • 使用 tokenizer 编码成批输入,送入模型。
  • 得到模型对每个句子对的预测 logits,softmax 转为概率。
  • probs[:, 1]: 取“相关”这个类别的概率,作为得分。

🔍 例子:

query = "糖尿病患者如何控制血糖?"
文档 d1 = "糖尿病是一种慢性病,需要控制饮食和规律运动。"
→ BERT 判断是否相关,输出一个概率,比如 0.88

六、生成每个 Query 的检索结果

retrieved_results = {}
for item in task_data:query = item["query"]retrieved_results[query] = query_retrieval(query, corpus)
  • 针对每个问题 query,调用上面定义的 query_retrieval,获取所有文档的相关性得分。

🔍 结果类似于:

{"糖尿病患者如何控制血糖?": {"d1": 0.88, "d2": 0.1, ...},...
}

七、构造标准答案(Ground Truth)

relevant_docs = {item["query"]: set(item["positive"]) for item in task_data}
  • 把每个 query 的 positive 列表转成集合,作为真实相关文档。

八、评估:Recall@K

def evaluate_recall(retrieved_results: Dict[str, Dict[str, float]], relevant_docs: Dict[str, set], k: int = 2):recalls = []for query, scores in retrieved_results.items():top_k_docs = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:k]top_k_doc_ids = [doc_id for doc_id, _ in top_k_docs]num_hits = len(set(top_k_doc_ids) & relevant_docs[query])recall = num_hits / len(relevant_docs[query])print(f"Query: {query}\nTop-{k}: {top_k_doc_ids}\nRecall@{k}: {recall:.2f}\n")recalls.append(recall)avg_recall = sum(recalls) / len(recalls)print(f"Average Recall@{k}: {avg_recall:.2f}")

详细解释:

  • 对每个 query,从检索结果中取出得分最高的前 k 篇文档。
  • 看这 k 篇文档里,有多少是正确答案(交集)。
  • Recall 计算方式:命中的正例数 / 实际正例数
  • 最后输出平均 Recall。

🔍 例子:

Query: 糖尿病如何控制血糖?
Top-2: ["d3", "d1"]  # 全部命中
Recall@2 = 2 / 2 = 1.00

九、运行评估

evaluate_recall(retrieved_results, relevant_docs, k=2)

会输出类似:

Query: 糖尿病患者如何控制血糖?
Top-2: ['d3', 'd1']
Recall@2: 1.00Query: 高血压如何治疗?
Top-2: ['d2', 'd4']
Recall@2: 1.00Average Recall@2: 1.00

总结:整个流程干了什么?

  1. 定义语料库和训练样本
  2. 用 BERT 模型对 query 和文档进行配对判断
  3. 打分并选出相关文档
  4. 根据真实正例,计算 Recall@K 评估效果
http://www.xdnf.cn/news/4834.html

相关文章:

  • 又双叒叕想盘一下systemd
  • 中小企业设备预测性维护三步构建法:从零到精的技术跃迁与中讯烛龙实践
  • BUUCTF——杂项渗透之1和0的故事
  • 6. 进程控制
  • 基于51单片机的自动洗衣机衣料材质proteus仿真
  • 冯诺依曼体系结构与操作系统
  • 2.6 点云数据存储格式——小结
  • 1128. 等价多米诺骨牌对的数量
  • Python Cookbook-7.7 通过 shelve 修改对象
  • HPLC+HRF双模载波组网过程简析
  • 【嵌入式开发-SDIO】
  • 前端获取流式数据并输出
  • 【Day 22】HarmonyOS车联网开发实战
  • vfrom表单设计器使用事件机制控制字段显示隐藏
  • 算法解密:除自身以外数组的乘积问题详解
  • robot_lab中amp_utils——retarget_kp_motions.py解析
  • 算法训练营第十一天|150. 逆波兰表达式求值、239. 滑动窗口最大值、347.前 K 个高频元素
  • 旅游设备生产企业的痛点 质检系统在旅游设备生产企业的应用
  • Python pandas 向excel追加数据,不覆盖之前的数据
  • <C#>log4net 的配置文件配置项详细介绍
  • python24-匿名函数
  • 2.5 特征值与特征向量
  • 二叉树的基本操作
  • es6/7练习题1
  • 微软推动智能体协同运作:支持 A2A、MCP 协议
  • mqtt选型,使用
  • 关键字where
  • Docker学习笔记
  • deeplabv3+街景图片语义分割,无需训练模型,看不懂也没有影响,直接使用,cityscapes数据集_25
  • python小说网站管理系统-小说阅读系统