LangGraph in Action: Self-Reflective RAG (Self-RAG)
Overview
Traditional LLMs rely solely on their internal parametric knowledge when generating answers, which makes them prone to factual errors (hallucinations). Even with retrieval-augmented generation (RAG), two problems remain: (1) blind retrieval: a fixed number of documents is retrieved regardless of whether retrieval is needed at all, potentially introducing irrelevant information; (2) mechanical integration: retrieved content is concatenated directly into the prompt, which can hurt the fluency or usefulness of the output.
In short, existing RAG pipelines cannot dynamically judge when retrieval is needed or whether the retrieved content is relevant, so generation quality is unstable.
Introduction to Self-RAG
In retrieval-augmented generation (RAG), the typical workflow queries a vector store with the user input to retrieve relevant documents, then uses a language model to generate a response. Self-RAG builds on this framework by adding a reflection step to RAG. This lets the LLM control when to retrieve, assess relevance, critique its own responses, and adjust its behavior to improve the accuracy and factuality of its answers.
Self-Reflective Retrieval-Augmented Generation (Self-RAG) is a framework that improves a language model's quality and factuality through retrieval and self-reflection. Self-RAG trains a single general-purpose model that can:
- Reflect on retrieval and generation:
  - The model is trained to retrieve on demand, triggering retrieval only when necessary.
  - Through reflection tokens, the model can evaluate both the retrieved content and its own output, dynamically adjusting its generation strategy.
- Improved controllability:
  - Reflection tokens make the model controllable at inference time, for example:
    - Marking whether retrieval is needed (e.g. Retrieve: Yes/No).
    - Assessing the relevance of retrieved content (e.g. Relevance: High/Medium/Low).
    - Verifying how well the generation is supported (e.g. Support: Fully/Partially).
  - Users can steer the model's behavior by controlling these tokens, adapting it to different task requirements (e.g. strict fact-checking vs. creative generation).
How Self-RAG's Self-Assessment Mechanism Works
A traditional RAG pipeline retrieves a fixed number of documents without considering their relevance or whether retrieval is even needed in the given context. Self-RAG improves on this by letting the model decide whether retrieval is necessary, which passages to include, and when to critique or refine its response. This self-reflective approach makes the workflow in LangGraph dynamic and adaptive.
Self-RAG uses reflection tokens to mark the decisions made during retrieval and generation:
- Retrieve token: decides whether to retrieve at all.
- ISREL token: assesses whether a retrieved document is relevant to the question.
- ISSUP token: checks that the generated answer is grounded in the retrieved documents.
- ISUSE token: measures the overall usefulness of the generated answer.
These tokens let the model adapt its behavior on the fly, maximizing the relevance, factuality, and contextual accuracy of its answers.
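To make the control flow these tokens induce concrete, here is a small illustrative sketch. It is not from the paper or the code below; the class, field names, and thresholds are invented for illustration. It models the four decisions as plain Python values and maps them to actions:
from dataclasses import dataclass

@dataclass
class ReflectionTokens:
    retrieve: bool   # Retrieve: is retrieval needed at all?
    is_rel: bool     # ISREL: is the retrieved passage relevant to the question?
    is_sup: bool     # ISSUP: is the generated answer supported by the passages?
    is_use: int      # ISUSE: overall usefulness of the answer (e.g. 1-5)

def next_action(tokens: ReflectionTokens) -> str:
    """Map the four Self-RAG decisions to a control-flow action (illustrative)."""
    if not tokens.retrieve:
        return "answer directly from parametric knowledge"
    if not tokens.is_rel:
        return "rewrite the query and retrieve again"
    if not tokens.is_sup:
        return "regenerate the answer"
    return "return the answer" if tokens.is_use >= 3 else "rewrite the query"
In the original paper these decisions are special tokens emitted by a single trained model; the LangGraph implementation below approximates each decision with a separate LLM grader.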
Self-RAG Workflow in LangGraph
LangGraph supports self-reflective RAG workflows that contain decision points and feedback loops, implemented through LangGraph's state machine. The example that follows shows how to build such a self-reflective workflow: each node simply modifies the state, and each edge selects the next node to invoke.
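Before the full example, a minimal self-contained sketch shows these two primitives at work (the state and node names here are illustrative, not part of the Self-RAG example):
from typing import TypedDict
from langgraph.graph import StateGraph, START, END

class MiniState(TypedDict):
    question: str
    answer: str

def answer_node(state: MiniState):
    # A node receives the current state and returns a partial state update.
    return {"answer": f"echo: {state['question']}"}

mini = StateGraph(MiniState)
mini.add_node("answer", answer_node)
mini.add_edge(START, "answer")   # edges decide which node runs next
mini.add_edge("answer", END)
mini_app = mini.compile()
print(mini_app.invoke({"question": "hi"}))  # -> {'question': 'hi', 'answer': 'echo: hi'}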
The implementation logic is as follows:
- Download the documents, split them into chunks, and build a chunk index.
- Retrieve the documents relevant to the question.
- Grade the retrieved documents, assessing their relevance to the question.
- Based on those grades, decide which documents proceed to the generate step, and generate an answer from the relevant documents.
- Check whether the generated answer is useful; if not, retrieve more documents or rewrite the question.
Code Implementation
The full implementation is shown below:
import os
from typing import List, TypedDict
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langgraph.graph import StateGraph, START, END
from langchain_core.pydantic_v1 import BaseModel, Field
from display_graph import display_graph
from langchain_core.embeddings import Embeddings
from langchain_community.vectorstores import FAISS
import requests
from dotenv import load_dotenv

load_dotenv()

# 1. Document indexing: load multiple sources to build a rich corpus,
#    ready for retrieval from a FAISS vector store.
# Load and prepare documents
urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]
# Split the documents into chunks
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=250, chunk_overlap=0)
doc_splits = text_splitter.split_documents(docs_list)

# Build the index from the embedding vectors. Note: CustomSiliconFlowEmbeddings
# is shown in the "Embedding Model Client Class" section below; in the actual
# script the class must be defined before this line runs.
vectorstore = FAISS.from_documents(documents=doc_splits, embedding=CustomSiliconFlowEmbeddings())
retriever = vectorstore.as_retriever()

# Set up prompt and model
prompt = ChatPromptTemplate.from_template("""
Use the following context to answer the question concisely:
Question: {question}
Context: {context}
Answer:
""")# 构建大模型
local_llm = "Qwen/Qwen2.5-7B-Instruct"
base_url = "https://api.siliconflow.cn/v1"
SL_API_KEY = os.getenv("SL_API_KEY")model = ChatOpenAI(model=local_llm, base_url=base_url, api_key=SL_API_KEY)#
#
# 2. Implement the adaptive workflow: define nodes for
#    retrieval -> generation -> grading -> query transformation.
#
rag_chain = prompt | model | StrOutputParser()

class GraphState(TypedDict):
    question: str
    generation: str
    documents: List[str]

# Grade document relevance
class GradeDocuments(BaseModel):
    binary_score: str = Field(description="Documents are relevant to the question, 'yes' or 'no'")

retrieval_prompt = ChatPromptTemplate.from_template("""
You are a grader assessing if a document is relevant to a user's question.
Document: {document}
Question: {question}
Is the document relevant? Answer 'yes' or 'no'.
""")
retrieval_grader = retrieval_prompt | model.with_structured_output(GradeDocuments)

# Check whether the generated answer is grounded in the documents
class GradeHallucinations(BaseModel):
    binary_score: str = Field(description="Answer is grounded in the documents, 'yes' or 'no'")

hallucination_prompt = ChatPromptTemplate.from_template("""
You are a grader assessing if an answer is grounded in retrieved documents.
Documents: {documents}
Answer: {generation}
Is the answer grounded in the documents? Answer 'yes' or 'no'.
""")
hallucination_grader = hallucination_prompt | model.with_structured_output(GradeHallucinations)

# Check whether the answer addresses the question
class GradeAnswer(BaseModel):
    binary_score: str = Field(description="Answer addresses the question, 'yes' or 'no'")

answer_prompt = ChatPromptTemplate.from_template("""
You are a grader assessing if an answer addresses the user's question.
Question: {question}
Answer: {generation}
Does the answer address the question? Answer 'yes' or 'no'.
""")
answer_grader = answer_prompt | model.with_structured_output(GradeAnswer)

# LangGraph node functions
def retrieve(state):
    question = state["question"]
    documents = retriever.invoke(question)
    return {"documents": documents, "question": question}

def generate(state):
    question = state["question"]
    documents = state["documents"]
    generation = rag_chain.invoke({"context": documents, "question": question})
    return {"documents": documents, "question": question, "generation": generation}

def grade_documents(state):
    """Grades documents based on relevance to the question.
    Only relevant documents are retained in 'relevant_docs'."""
    question = state["question"]
    documents = state["documents"]
    relevant_docs = []
    for doc in documents:
        response = retrieval_grader.invoke({"question": question, "document": doc.page_content})
        if response.binary_score == "yes":
            relevant_docs.append(doc)
    return {"documents": relevant_docs, "question": question}

def decide_to_generate(state):
    """Decides whether to proceed with generation or transform the query."""
    if not state["documents"]:
        return "transform_query"  # No relevant docs found; rephrase the query
    return "generate"  # Relevant docs found; proceed to generate

def grade_generation_v_documents_and_question(state):
    """Checks if the generation is grounded in the retrieved documents and answers the question."""
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]
    # Step 1: Check if the generation is grounded in the retrieved documents
    hallucination_check = hallucination_grader.invoke({"documents": documents, "generation": generation})
    if hallucination_check.binary_score == "no":
        return "not supported"  # Regenerate if the generation isn't grounded in the documents
    # Step 2: Check if the generation addresses the question
    answer_check = answer_grader.invoke({"question": question, "generation": generation})
    return "useful" if answer_check.binary_score == "yes" else "not useful"

def transform_query(state):
    """Rephrases the query for improved retrieval if initial attempts do not yield relevant documents."""
    transform_prompt = ChatPromptTemplate.from_template("""
    You are a question re-writer that converts an input question to a better version optimized for retrieving relevant documents.
    Original question: {question}
    Please provide a rephrased question.
    """)
    question_rewriter = transform_prompt | model | StrOutputParser()
    question = state["question"]
    # Rephrase the question using the LLM
    transformed_question = question_rewriter.invoke({"question": question})
    return {"question": transformed_question, "documents": state["documents"]}
#
# Workflow setup: in LangGraph we define the graph's state, edges, and
# conditional nodes to build the adaptive feedback loop.
#
# Set up the workflow graph
workflow = StateGraph(GraphState)
workflow.add_node("retrieve", retrieve)
workflow.add_node("generate", generate)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("transform_query", transform_query)
workflow.add_edge(START, "retrieve")
# Initial query flow
workflow.add_edge("retrieve", "grade_documents")
# After grading, decide whether to generate an answer or rewrite the query
workflow.add_conditional_edges("grade_documents", decide_to_generate, {"transform_query": "transform_query", "generate": "generate"})
# If the retrieved documents are not relevant, rewrite the query via transform_query
workflow.add_edge("transform_query", "retrieve")
# If the generated answer is not grounded in the documents, regenerate it;
# if it does not address the question well, rewrite the query instead
workflow.add_conditional_edges("generate", grade_generation_v_documents_and_question, {"not supported": "generate", "useful": END, "not useful": "transform_query"})# Compile the app and run
app = workflow.compile()# Display the graph
display_graph(app, file_name=os.path.basename(__file__))# Example input
inputs = {"question": "Explain how the different types of agent memory work?"}
for output in app.stream(inputs):print(output)
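A note on the output: with LangGraph's default streaming mode, app.stream(inputs) yields one dictionary per executed node, keyed by the node name and containing that node's state update, so you can watch the retrieve, grade_documents, generate (and, when triggered, transform_query) steps fire in order.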
Embedding Model Client Class
#
# Custom embedding model: instead of OpenAI embeddings, this calls the
# SiliconFlow API to obtain embeddings.
#
class CustomSiliconFlowEmbeddings(Embeddings):
    def __init__(self,
                 api_key: str = os.getenv("SL_API_KEY"),
                 base_url: str = "https://api.siliconflow.cn/v1/embeddings",
                 model: str = "BAAI/bge-large-zh-v1.5"):
        self.api_key = api_key
        self.base_url = base_url
        self.model = model

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents."""
        embeddings = []
        for text in texts:
            embedding = self.embed_query(text)
            embeddings.append(embedding)
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string."""
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        payload = {
            "model": self.model,
            "input": text,
            "encoding_format": "float"
        }
        response = requests.post(self.base_url, json=payload, headers=headers)
        if response.status_code == 200:
            return response.json()["data"][0]["embedding"]
        raise Exception(f"Error in embedding: {response.text}")
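A quick way to sanity-check the class on its own (assuming SL_API_KEY is set in the environment; the query text is arbitrary):
emb = CustomSiliconFlowEmbeddings()
vec = emb.embed_query("What is agent memory?")
print(len(vec))  # dimensionality of a BAAI/bge-large-zh-v1.5 embedding vector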
Summary
Self-RAG strengthens the RAG pipeline and makes its results more reliable. But while Self-RAG makes the RAG workflow more adaptive, it also introduces some complexity: (1) latency: the reflection cycles increase processing time; (2) resource usage: iterative reflection consumes additional compute; (3) implementation complexity: designing an adaptive workflow with feedback loops requires careful tuning.
References
- LangGraph Self-RAG cookbook: https://github.com/langchain-ai/langchain/blob/master/cookbook/langgraph_self_rag.ipynb
- Self-RAG paper (arXiv:2310.11511): https://arxiv.org/abs/2310.11511
- LangGraph documentation: https://langchain-ai.github.io/langgraph/