LlamaIndex RAG (Retrieval-Augmented Generation) Primer

1. Building a RAG pipeline with LlamaIndex

A minimal LlamaIndex RAG example:
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex
from llama_index.llms.huggingface import HuggingFaceLLM

# Initialize a HuggingFaceEmbedding object that converts text into vector representations
embed_model = HuggingFaceEmbedding(
    # Path to a pretrained sentence-transformers model
    model_name="/home/cw/llms/embedding_model/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
)
# Register the embedding model in the global settings so that
# index construction uses it later on
Settings.embed_model = embed_model

# Load a local LLM with HuggingFaceLLM
llm = HuggingFaceLLM(
    model_name="/home/cw/llms/Qwen/Qwen1.5-1.8B-Chat",
    tokenizer_name="/home/cw/llms/Qwen/Qwen1.5-1.8B-Chat",
    model_kwargs={"trust_remote_code": True},
    tokenizer_kwargs={"trust_remote_code": True},
)
# Register the LLM globally so index queries use this model
Settings.llm = llm

# Read documents from the given directory and load them into memory
documents = SimpleDirectoryReader("/home/cw/projects/demo_17/data").load_data()
# print(documents)

# Build a VectorStoreIndex from the loaded documents.
# The index embeds each document and keeps the vectors in memory for fast retrieval.
index = VectorStoreIndex.from_documents(documents)

# Create a query engine that accepts a query and returns a response
# grounded in the relevant documents
query_engine = index.as_query_engine()
rsp = query_engine.query("xtuner是什么?")
print(rsp)
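To check which document chunks the answer was grounded on, the Response object returned by query() exposes the retrieved chunks through its source_nodes attribute. A short sketch:

# Inspect the retrieved chunks behind the answer
for node in rsp.source_nodes:
    print(node.score, node.node.text[:80])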
2. Using the Chroma vector database
import chromadb
from sentence_transformers import SentenceTransformer

# A custom embedding function wrapping a local sentence-transformers model
class SentenceTransformerEmbeddingFunction:
    def __init__(self, model_path: str, device: str = "cuda"):
        self.model = SentenceTransformer(model_path, device=device)

    def __call__(self, input: list[str]) -> list[list[float]]:
        if isinstance(input, str):
            input = [input]
        return self.model.encode(input, convert_to_numpy=True).tolist()

# Create the embedding function (use device="cpu" if no GPU is available)
embed_model = SentenceTransformerEmbeddingFunction(
    model_path="/root/models/BAAI/bge-m3",
    device="cuda",
)

# Create the client and the collection (with the custom embedding function)
client = chromadb.Client()
collection = client.create_collection(
    "my_knowledge_base",
    metadata={"hnsw:space": "cosine"},
    embedding_function=embed_model,
)

# Add documents (documents, metadatas and ids must have matching lengths)
collection.add(
    documents=["RAG是一种检索增强生成技术", "向量数据库存储文档的嵌入表示"],
    metadatas=[{"source": "tech_doc"}, {"source": "tutorial"}],
    ids=["doc1", "doc2"],
)

# Query for similar documents
results = collection.query(
    query_texts=["什么是RAG技术?"],
    n_results=2,
)
print(results)

# Update an existing document in place
collection.update(
    ids=["doc1"],  # must be an existing ID
    documents=["更新后的RAG技术内容"],
)

# Check the update, option 1: get() the specific IDs
updated_docs = collection.get(ids=["doc1"])
print("更新后的文档内容:", updated_docs["documents"])

# Check the update, option 2: fetch every document
all_docs = collection.get()
print("集合中所有文档:", all_docs["documents"])

# Delete a document
collection.delete(ids=["doc1"])

# Verify the deletion
all_docs = collection.get()
print("集合中所有文档:", all_docs["documents"])

# Count the remaining entries
print(collection.count())
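Queries can also be restricted by metadata: Chroma's query() accepts a where filter. For example, to search only documents tagged with the "tutorial" source added above:

# Restrict the search to documents whose metadata matches the filter
results = collection.query(
    query_texts=["什么是RAG技术?"],
    n_results=1,
    where={"source": "tutorial"},
)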
3. Defining the conversation prompt template
The template follows Qwen's ChatML chat format; at query time, LlamaIndex fills {context_str} with the retrieved passages and {query_str} with the user's question.

from llama_index.core import PromptTemplate

QA_TEMPLATE = (
    "<|im_start|>system\n"
    "你是一个专业的法律助手,请严格根据以下法律条文回答问题:\n"
    "相关法律条文:\n{context_str}\n<|im_end|>\n"
    "<|im_start|>user\n{query_str}<|im_end|>\n"
    "<|im_start|>assistant\n"
)
response_template = PromptTemplate(QA_TEMPLATE)
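To preview the exact prompt string the synthesizer will send to the model, you can fill the placeholders by hand with PromptTemplate's format() method (the values below are illustrative):

# Fill the placeholders manually to inspect the final prompt
print(response_template.format(
    context_str="(此处为检索到的法律条文)",
    query_str="试用期最长可以约定多久?",
))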
4. Persisting the Chroma database and reading it back
import chromadb
from llama_index.core import VectorStoreIndex, StorageContext, Settings
from llama_index.vector_stores.chroma import ChromaVectorStore

# ################## Persisting the data ##################
# Create a persistent client
chroma_client = chromadb.PersistentClient(path=Config.VECTOR_DB_DIR)
# Create (or fetch) the collection
chroma_collection = chroma_client.get_or_create_collection(
    name=Config.COLLECTION_NAME,
    metadata={"hnsw:space": "cosine"},
)
# Make sure the storage context is initialized with the Chroma vector store
storage_context = StorageContext.from_defaults(
    vector_store=ChromaVectorStore(chroma_collection=chroma_collection)
)
# Explicitly add the nodes to the storage context's docstore
storage_context.docstore.add_documents(nodes)
index = VectorStoreIndex(
    nodes,
    storage_context=storage_context,
    show_progress=True,
)
# Persist twice for good measure: the storage context and the index's own context
storage_context.persist(persist_dir=Config.PERSIST_DIR)
index.storage_context.persist(persist_dir=Config.PERSIST_DIR)

# ################## Loading an existing index ##################
print("加载已有索引...")
storage_context = StorageContext.from_defaults(
    persist_dir=Config.PERSIST_DIR,
    vector_store=ChromaVectorStore(chroma_collection=chroma_collection),
)
index = VectorStoreIndex.from_vector_store(
    storage_context.vector_store,
    storage_context=storage_context,
    embed_model=Settings.embed_model,
)
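In practice the two halves above are alternatives: build and persist when the collection is still empty, otherwise load what is already on disk. The complete listing in section 7 branches on the collection's count; the core of that check looks like this:

# Build a new index only when the collection is empty and nodes are available;
# otherwise load the persisted one (same logic as init_vector_store in section 7)
if chroma_collection.count() == 0 and nodes is not None:
    storage_context.docstore.add_documents(nodes)
    index = VectorStoreIndex(nodes, storage_context=storage_context, show_progress=True)
    storage_context.persist(persist_dir=Config.PERSIST_DIR)
else:
    index = VectorStoreIndex.from_vector_store(
        storage_context.vector_store,
        storage_context=storage_context,
        embed_model=Settings.embed_model,
    )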
5. Creating the retriever and response synthesizer
from llama_index.core import get_response_synthesizer
from llama_index.core.postprocessor import SentenceTransformerRerank  # reranking component

# Create the retriever and the response synthesizer
retriever = index.as_retriever(
    similarity_top_k=Config.TOP_K  # retrieve a larger initial candidate set
)
response_synthesizer = get_response_synthesizer(
    text_qa_template=response_template,
    verbose=True,
)

# 1. Initial retrieval
initial_nodes = retriever.retrieve(question)
# Save each node's initial score into its metadata before reranking replaces it
for node in initial_nodes:
    node.node.metadata['initial_score'] = node.score

# 2. Reranking
# Initialize the reranker
reranker = SentenceTransformerRerank(
    model=Config.RERANK_MODEL_PATH,
    top_n=Config.RERANK_TOP_K,
)
reranked_nodes = reranker.postprocess_nodes(initial_nodes, query_str=question)

# 3. Synthesize the answer from the reranked nodes
response = response_synthesizer.synthesize(question, nodes=reranked_nodes)
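Because each node's pre-rerank score was stashed in its metadata, you can print the two scores side by side to see how much the reranker changed the ordering (these are the same fields the complete listing in section 7 displays):

# Compare initial retrieval scores with the reranker's scores
for node in reranked_nodes:
    print(
        f"初始相关度 {node.node.metadata['initial_score']:.4f} -> "
        f"重排序得分 {node.score:.4f}: {node.node.text[:40]}..."
    )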
6. Configuring the models
from llama_index.core import Settings
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.openai_like import OpenAILike

# Embedding model
embed_model = HuggingFaceEmbedding(
    model_name=Config.EMBED_MODEL_PATH,
)

# LLM, option 1: load the local model files directly
llm = HuggingFaceLLM(
    model_name=Config.LLM_MODEL_PATH,
    tokenizer_name=Config.LLM_MODEL_PATH,
    model_kwargs={"trust_remote_code": True},
    tokenizer_kwargs={"trust_remote_code": True},
    generate_kwargs={"temperature": 0.3},
)

# LLM, option 2: talk to the model through an OpenAI-compatible API
# (keep only one of the two options; as written, option 2 overwrites option 1)
llm = OpenAILike(
    model="/home/cw/llms/Qwen/Qwen1.5-1.8B-Chat",
    api_base="http://localhost:8000/v1",
    api_key="fake",
    context_window=4096,
    is_chat_model=True,
    is_function_calling_model=False,
)

Settings.embed_model = embed_model
Settings.llm = llm
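Option 2 assumes an OpenAI-compatible server (for example vLLM) is already serving the model at http://localhost:8000/v1. Either way, a quick smoke test confirms that the globally registered LLM responds (the prompt is just an example):

# Smoke test for whichever LLM was registered in Settings
print(Settings.llm.complete("用一句话介绍RAG"))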
7. Complete code
# -*- coding: utf-8 -*-
import json
import time
from pathlib import Path
from typing import List, Dict

import chromadb
from llama_index.core import VectorStoreIndex, StorageContext, Settings, get_response_synthesizer
from llama_index.core.schema import TextNode
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.core import PromptTemplate
from llama_index.core.postprocessor import SentenceTransformerRerank  # reranking component

# Note: only {context_str} and {query_str} are filled in by the response synthesizer
QA_TEMPLATE = (
    "<|im_start|>system\n"
    "您是中国劳动法领域专业助手,必须严格遵循以下规则:\n"
    "1.仅使用提供的法律条文回答问题\n"
    "2.若问题与劳动法无关或超出知识库范围,明确告知无法回答\n"
    "3.引用条文时标注出处\n\n"
    "可用法律条文:\n{context_str}\n<|im_end|>\n"
    "<|im_start|>user\n问题:{query_str}<|im_end|>\n"
    "<|im_start|>assistant\n"
)
response_template = PromptTemplate(QA_TEMPLATE)


# ================== Configuration ==================
class Config:
    EMBED_MODEL_PATH = r"/home/cw/llms/embedding_model/sungw111/text2vec-base-chinese-sentence"
    RERANK_MODEL_PATH = r"/home/cw/llms/rerank_model/BAAI/bge-reranker-large"  # reranker model path
    LLM_MODEL_PATH = r"/home/cw/llms/Qwen/Qwen1___5-1___8B-Chat"
    DATA_DIR = "./data"
    VECTOR_DB_DIR = "./chroma_db"
    PERSIST_DIR = "./storage"
    COLLECTION_NAME = "chinese_labor_laws"
    TOP_K = 10        # size of the initial retrieval pool
    RERANK_TOP_K = 3  # nodes kept after reranking


# ================== Model initialization ==================
def init_models():
    """Initialize the models and run a quick validation."""
    # Embedding model
    embed_model = HuggingFaceEmbedding(
        model_name=Config.EMBED_MODEL_PATH,
    )
    # LLM
    llm = HuggingFaceLLM(
        model_name=Config.LLM_MODEL_PATH,
        tokenizer_name=Config.LLM_MODEL_PATH,
        model_kwargs={"trust_remote_code": True},
        tokenizer_kwargs={"trust_remote_code": True},
        generate_kwargs={"temperature": 0.3},
    )
    # Reranker
    reranker = SentenceTransformerRerank(
        model=Config.RERANK_MODEL_PATH,
        top_n=Config.RERANK_TOP_K,
    )

    Settings.embed_model = embed_model
    Settings.llm = llm

    # Sanity-check the embedding model
    test_embedding = embed_model.get_text_embedding("测试文本")
    print(f"Embedding维度验证:{len(test_embedding)}")

    return embed_model, llm, reranker


# ================== Vector store ==================
def init_vector_store(nodes: List[TextNode]) -> VectorStoreIndex:
    chroma_client = chromadb.PersistentClient(path=Config.VECTOR_DB_DIR)
    chroma_collection = chroma_client.get_or_create_collection(
        name=Config.COLLECTION_NAME,
        metadata={"hnsw:space": "cosine"},
    )
    # Make sure the storage context is initialized with the Chroma vector store
    storage_context = StorageContext.from_defaults(
        vector_store=ChromaVectorStore(chroma_collection=chroma_collection)
    )

    # Decide whether a new index needs to be built
    if chroma_collection.count() == 0 and nodes is not None:
        print(f"创建新索引({len(nodes)}个节点)...")
        # Explicitly add the nodes to the docstore
        storage_context.docstore.add_documents(nodes)
        index = VectorStoreIndex(
            nodes,
            storage_context=storage_context,
            show_progress=True,
        )
        # Persist twice for good measure
        storage_context.persist(persist_dir=Config.PERSIST_DIR)
        index.storage_context.persist(persist_dir=Config.PERSIST_DIR)
    else:
        print("加载已有索引...")
        storage_context = StorageContext.from_defaults(
            persist_dir=Config.PERSIST_DIR,
            vector_store=ChromaVectorStore(chroma_collection=chroma_collection),
        )
        index = VectorStoreIndex.from_vector_store(
            storage_context.vector_store,
            storage_context=storage_context,
            embed_model=Settings.embed_model,
        )

    # Sanity checks on the docstore
    print("\n存储验证结果:")
    doc_count = len(storage_context.docstore.docs)
    print(f"DocStore记录数:{doc_count}")
    if doc_count > 0:
        sample_key = next(iter(storage_context.docstore.docs.keys()))
        print(f"示例节点ID:{sample_key}")
    else:
        print("警告:文档存储为空,请检查节点添加逻辑!")
    return index


# ================== Main ==================
def main():
    embed_model, llm, reranker = init_models()

    # Build nodes only when the vector store does not exist yet.
    # load_and_validate_json_files and create_nodes are the project's data-loading
    # helpers (not reproduced in this excerpt; see the sketch below)
    if not Path(Config.VECTOR_DB_DIR).exists():
        print("\n初始化数据...")
        raw_data = load_and_validate_json_files(Config.DATA_DIR)
        nodes = create_nodes(raw_data)
    else:
        nodes = None

    print("\n初始化向量存储...")
    start_time = time.time()
    index = init_vector_store(nodes)
    print(f"索引加载耗时:{time.time() - start_time:.2f}s")

    # Create the retriever and the response synthesizer
    retriever = index.as_retriever(
        similarity_top_k=Config.TOP_K  # retrieve a larger initial candidate set
    )
    response_synthesizer = get_response_synthesizer(
        text_qa_template=response_template,
        verbose=True,
    )

    # Interactive query loop
    while True:
        question = input("\n请输入劳动法相关问题(输入q退出): ")
        if question.lower() == 'q':
            break

        # Retrieve -> rerank -> synthesize
        start_time = time.time()

        # 1. Initial retrieval
        initial_nodes = retriever.retrieve(question)
        retrieval_time = time.time() - start_time
        for node in initial_nodes:
            node.node.metadata['initial_score'] = node.score  # keep the pre-rerank score

        # 2. Reranking
        reranked_nodes = reranker.postprocess_nodes(initial_nodes, query_str=question)
        rerank_time = time.time() - start_time - retrieval_time

        # 3. Answer synthesis
        response = response_synthesizer.synthesize(question, nodes=reranked_nodes)
        synthesis_time = time.time() - start_time - retrieval_time - rerank_time

        # Show the answer and its supporting evidence
        print(f"\n智能助手回答:\n{response.response}")
        print("\n支持依据:")
        for idx, node in enumerate(reranked_nodes, 1):
            meta = node.node.metadata
            print(f"\n[{idx}] {meta['full_title']}")
            print(f"  来源文件:{meta['source_file']}")
            print(f"  法律名称:{meta['law_name']}")
            print(f"  初始相关度:{meta.get('initial_score', node.score):.4f}")
            print(f"  重排序得分:{node.score:.4f}")
            print(f"  条款内容:{node.node.text[:100]}...")
        print(f"\n[性能分析] 检索: {retrieval_time:.2f}s | 重排序: {rerank_time:.2f}s | 合成: {synthesis_time:.2f}s")


if __name__ == "__main__":
    main()
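The listing calls load_and_validate_json_files and create_nodes, which are not reproduced in this excerpt. Below is a minimal sketch of what they might look like, assuming each JSON file in DATA_DIR maps a law name to a dictionary of {article title: article text}; that schema is an assumption, and only the metadata keys full_title, law_name and source_file are confirmed by the display code above.

# Hypothetical sketch; the real implementations are not shown in this excerpt.
# Uses json, Path, List, Dict and TextNode, already imported in the listing above.
def load_and_validate_json_files(data_dir: str) -> List[Dict]:
    """Load every JSON file under data_dir and tag each record with its source file."""
    records = []
    for path in Path(data_dir).glob("*.json"):
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if not isinstance(data, dict):
            raise ValueError(f"{path.name} 的结构不是预期的字典")
        records.append({"file": path.name, "laws": data})
    return records

def create_nodes(raw_data: List[Dict]) -> List[TextNode]:
    """Build TextNodes carrying the metadata keys that main() later displays."""
    nodes = []
    for record in raw_data:
        for law_name, articles in record["laws"].items():
            for article_title, article_text in articles.items():
                nodes.append(TextNode(
                    text=article_text,
                    metadata={
                        "full_title": f"{law_name} {article_title}",
                        "law_name": law_name,
                        "source_file": record["file"],
                    },
                ))
    return nodes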