
Civil Engineering Smart Assistant - Civil Engineering Knowledge QA System 02: Building the RAG Retrieval Module

1. Project Directory

civil_qa_system/
├── docs/                    # Project documentation
├── config/                  # Configuration files
├── core/                    # Core functionality
├── knowledge_base/          # Knowledge-base code
├── web/                     # Web application
├── cli/                     # Command-line tools
├── tests/                   # Test code
├── scripts/                 # Helper scripts
├── requirements/            # Dependency management
└── README.md                # Project overview

2. Naming Conventions

  • Class names: PascalCase (upper camel case), e.g. MyClass
  • Function names: snake_case, e.g. my_function
  • Variable names: snake_case, e.g. my_variable
  • Folder names: snake_case.

3. Connecting to the LLM

Create core/llm/qwen_client.py under the core directory. This file centralizes all of the LLM-related code.

I use the Tongyi Qianwen (Qwen) model here. You can of course use a different model, but you will need to adjust the code and configuration accordingly.

Create a new file under core: conf/.qwen

Configure the model's API key in it, replacing your_api_key_here below with your own key. If you have never used one before, apply for it on Alibaba Cloud (application page).

# Tongyi Qianwen API configuration
DASHSCOPE_API_KEY=your_api_key_here

The following packages are required; install them with pip in your terminal:
langchain-community>=0.0.28
python-dotenv>=1.0.0
dashscope>=1.14.0

from dotenv import load_dotenv
import os
from typing import Tuple
from langchain_community.llms.tongyi import Tongyi
from langchain_community.chat_models import ChatTongyi
from langchain_community.embeddings import DashScopeEmbeddings
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def load_qwen_config() -> bool:
    """Load the Qwen environment configuration.

    Returns:
        bool: whether loading succeeded
    """
    try:
        current_dir = os.path.dirname(__file__)
        conf_file_path_qwen = os.path.join(current_dir, '..', 'conf', '.qwen')
        if not os.path.exists(conf_file_path_qwen):
            logger.error(f"Qwen config file not found at: {conf_file_path_qwen}")
            return False
        load_dotenv(dotenv_path=conf_file_path_qwen)
        return True
    except Exception:
        logger.exception("Failed to load Qwen configuration")
        return False


def get_qwen_models() -> Tuple[Tongyi, ChatTongyi, DashScopeEmbeddings]:
    """Initialize and return the Qwen model components.

    Returns:
        Tuple: the (llm, chat, embed) triple

    Raises:
        RuntimeError: if configuration loading or model initialization fails
    """
    if not load_qwen_config():
        raise RuntimeError("Qwen configuration loading failed")
    try:
        # Initialize the LLM
        llm = Tongyi(
            model="qwen-max",
            temperature=0.1,
            top_p=0.7,
            max_tokens=1024,
            verbose=True
        )
        # Initialize the chat model
        chat = ChatTongyi(
            model="qwen-max",
            temperature=0.01,
            top_p=0.2,
            max_tokens=1024
        )
        # Initialize the embedding model
        embed = DashScopeEmbeddings(model="text-embedding-v3")
        logger.info("Qwen models initialized successfully")
        return llm, chat, embed
    except Exception as e:
        logger.exception("Failed to initialize Qwen models")
        raise RuntimeError(f"Model initialization failed: {str(e)}")

While writing code, we follow the habit of testing every small module as soon as it is written. Create unit/test_qwen_client.py under tests.

Now write the test code. What we wrote above is the LLM connection, so the test should verify that we can actually connect.

Install the testing tool pytest with pip; it discovers and runs the test functions automatically, so you do not need to call them yourself.

import pytest
from core.llm.qwen_client import get_qwen_models


class TestQwenClient:
    def test_model_initialization(self):
        """Test that the models initialize successfully."""
        llm, chat, embed = get_qwen_models()
        assert llm is not None
        assert chat is not None
        assert embed is not None

    def test_invalid_config(self, monkeypatch):
        """Test the misconfigured case."""
        monkeypatch.setenv("DASHSCOPE_API_KEY", "")
        with pytest.raises(RuntimeError):
            get_qwen_models()

Running the test file returns:

Testing started at 15:40 ...
Launching pytest with arguments test_qwen_client.py::TestQwenClient::test_invalid_config --no-header --no-summary -q in D:\construction_QA_system\tests\unit
============================= test session starts =============================
collecting ... collected 1 item

test_qwen_client.py::TestQwenClient::test_invalid_config PASSED          [100%]
============================== 1 passed in 0.85s ==============================

This means the test passed.

To help you avoid mistakes, here is my own project layout:

4. Implementing the Vector Store Class

We use Chroma here. If you have not used it before, take a look at the official documentation (or tutorials on Bilibili).

Install with pip:

chromadb>=0.4.15
langchain-chroma>=0.0.4

Create a storage/chroma_manager.py module under knowledge_base. It implements creating the Chroma vector store, adding documents to it, and querying documents.

from typing import Optional, List, Union
import chromadb
from chromadb import Settings
from langchain_chroma import Chroma
from langchain_core.embeddings import Embeddings
from langchain_core.documents import Document
import logging


class ChromaManager:
    """High-level wrapper around the ChromaDB vector database.

    Features:
    - Supports both local and HTTP connection modes
    - Automatic persistence management
    - Thread-safe connections
    - Robust error handling
    """

    def __init__(self,
                 chroma_server_type: str = "local",
                 host: str = "localhost",
                 port: int = 8000,
                 persist_path: str = "chroma_db",
                 collection_name: str = "langchain",
                 embed_model: Optional[Embeddings] = None):
        """Initialize the ChromaDB connection.

        Args:
            chroma_server_type: connection type ("local" | "http")
            host: server address (required in HTTP mode)
            port: server port (required in HTTP mode)
            persist_path: local persistence path (required in local mode)
            collection_name: collection name
            embed_model: embedding model instance
        """
        self._validate_init_params(chroma_server_type, host, port, persist_path)
        self.client = self._create_client(chroma_server_type, host, port, persist_path)
        self.collection_name = collection_name
        self.embed_model = embed_model
        self.logger = logging.getLogger(__name__)
        try:
            self.store = Chroma(
                collection_name=collection_name,
                embedding_function=embed_model,
                client=self.client,
                persist_directory=persist_path if chroma_server_type == "local" else None
            )
            self.logger.info(f"ChromaDB initialized successfully. Mode: {chroma_server_type}")
        except Exception as e:
            self.logger.error(f"ChromaDB initialization failed: {str(e)}")
            raise RuntimeError(f"Failed to initialize ChromaDB: {str(e)}")

    def _validate_init_params(self, server_type: str, host: str, port: int, path: str):
        """Validate constructor parameters."""
        if server_type not in ["local", "http"]:
            raise ValueError(f"Invalid server type: {server_type}. Must be 'local' or 'http'")
        if server_type == "http" and not all([host, port]):
            raise ValueError("Host and port must be specified for HTTP mode")
        if server_type == "local" and not path:
            raise ValueError("Persist path must be specified for local mode")

    def _create_client(self, server_type: str, host: str, port: int, path: str) -> chromadb.Client:
        """Create the Chroma client."""
        try:
            if server_type == "http":
                return chromadb.HttpClient(
                    host=host,
                    port=port,
                    settings=Settings(allow_reset=True)
                )
            return chromadb.PersistentClient(
                path=path,
                settings=Settings(anonymized_telemetry=False, allow_reset=True)
            )
        except Exception as e:
            logging.error(f"Chroma client creation failed: {str(e)}")
            raise

    def add_documents(self, docs: Union[List[Document], List[str]]) -> List[str]:
        """Add documents to the collection.

        Args:
            docs: a list of Document objects or plain strings

        Returns:
            the IDs of the inserted documents
        """
        try:
            if not docs:
                self.logger.warning("Attempted to add empty documents list")
                return []
            # Wrap plain strings so the store always receives Document objects
            docs = [Document(page_content=d) if isinstance(d, str) else d for d in docs]
            doc_ids = self.store.add_documents(documents=docs)
            self.logger.info(f"Added {len(doc_ids)} documents to collection '{self.collection_name}'")
            return doc_ids
        except Exception as e:
            self.logger.error(f"Failed to add documents: {str(e)}")
            raise RuntimeError(f"Document addition failed: {str(e)}")

    def query(self, query_text: str, k: int = 5, **kwargs) -> List[Document]:
        """Similarity search.

        Args:
            query_text: the query text
            k: number of results to return
            **kwargs: extra query parameters

        Returns:
            the matching documents
        """
        try:
            results = self.store.similarity_search(query_text, k=k, **kwargs)
            self.logger.debug(f"Query returned {len(results)} results for: {query_text}")
            return results
        except Exception as e:
            self.logger.error(f"Query failed: {str(e)}")
            raise RuntimeError(f"Query operation failed: {str(e)}")

    def get_collection_stats(self) -> dict:
        """Return collection statistics."""
        try:
            collection = self.client.get_collection(self.collection_name)
            return {"count": collection.count(), "metadata": collection.metadata}
        except Exception as e:
            self.logger.error(f"Failed to get collection stats: {str(e)}")
            raise RuntimeError(f"Collection stats retrieval failed: {str(e)}")

    @property
    def store(self) -> Chroma:
        """Return the LangChain Chroma instance."""
        return self._store

    @store.setter
    def store(self, value):
        self._store = value
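Before wiring ChromaManager into pytest, it can be smoke-tested interactively with a stub embedding model so no API key or network access is needed. The FakeEmbeddings class below is a hypothetical stand-in I am adding for illustration; it is not part of the project:

```python
class FakeEmbeddings:
    """Deterministic stand-in for a real embedding model (illustration only)."""

    def embed_documents(self, texts):
        # One fixed-length vector per input text
        return [[float(len(t))] * 8 for t in texts]

    def embed_query(self, text):
        return [float(len(text))] * 8


# With the stub above, ChromaManager can be exercised offline:
# db = ChromaManager(chroma_server_type="local", persist_path="chroma_db",
#                    embed_model=FakeEmbeddings())
# db.add_documents(["Concrete strength spec ...", "Rebar spacing rules ..."])
# print(db.query("concrete strength", k=1))
```

The same idea (a mocked embedding model) is what the pytest fixture in the next step relies on.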

Let's test this module's functionality:

import pytest
from unittest.mock import MagicMock
from knowledge_base.storage.chroma_manager import ChromaManager


class TestChromaManager:
    @pytest.fixture
    def mock_embedding(self):
        mock = MagicMock()
        # Return one vector per input text
        mock.embed_documents.side_effect = lambda texts: [[0.1] * 768 for _ in texts]
        return mock

    def test_local_init(self, tmp_path, mock_embedding):
        """Test initialization in local mode."""
        db = ChromaManager(
            chroma_server_type="local",
            persist_path=str(tmp_path),
            embed_model=mock_embedding
        )
        assert db.store is not None

    def test_add_documents(self, tmp_path, mock_embedding):
        """Test adding documents."""
        db = ChromaManager(
            chroma_server_type="local",
            persist_path=str(tmp_path),
            embed_model=mock_embedding
        )
        test_docs = ["Test document 1", "Test document 2"]
        doc_ids = db.add_documents(test_docs)
        assert len(doc_ids) == 2

Execution result:

Testing started at 17:00 ...
Launching pytest with arguments test_chroma_manager.py::TestChromaManager::test_local_init --no-header --no-summary -q in D:\construction_QA_system\tests\unit
============================= test session starts =============================
collecting ... collected 1 item

test_chroma_manager.py::TestChromaManager::test_local_init PASSED        [100%]
============================== 1 passed in 1.79s ==============================

Process finished with exit code 0

5. Implementing Document Ingestion

Create builders/pdf_processor.py under the knowledge_base folder and implement the PDFProcessor class. Its main responsibilities:

- Load PDF files from a given directory
- Extract the text content
- Split the text into chunks
- Store the chunks in the vector database
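The chunking step in this list is worth seeing in isolation. Below is a simplified, stdlib-only sketch of fixed-size chunking with overlap, using the same defaults as the class (500-character chunks, 100-character overlap); the real pipeline uses LangChain's RecursiveCharacterTextSplitter, which additionally prefers to break at paragraph and sentence boundaries:

```python
def chunk_text(text: str, chunk_size: int = 500, overlap: int = 100) -> list:
    """Naive fixed-size chunking with overlap (illustration only).

    The actual PDFProcessor uses RecursiveCharacterTextSplitter, which
    also tries to split at separators such as newlines and sentence ends.
    """
    step = chunk_size - overlap  # each chunk starts 400 chars after the last
    return [text[i:i + chunk_size] for i in range(0, len(text), step)]


chunks = chunk_text("x" * 1200)
print([len(c) for c in chunks])  # [500, 500, 400]
```

The overlap means the last 100 characters of one chunk reappear at the start of the next, so sentences cut at a chunk boundary remain fully retrievable.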

 

import os
import logging
import time
from tqdm import tqdm
from typing import List, Optional
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from knowledge_base.storage.chroma_manager import ChromaManager


class PDFProcessor:
    """PDF processing pipeline, responsible for:
    - loading PDF files from a given directory
    - extracting the text content
    - splitting the text into chunks
    - storing the chunks in the vector database

    Args:
        directory: path to the directory containing the PDFs
        chroma_server_type: ChromaDB server type ("local" or "http")
        persist_path: ChromaDB persistence path (local mode)
        embedding_function: embedding model instance
        file_group_num: number of files processed per group (default 80)
        batch_num: number of documents inserted per batch (default 6)
        chunksize: chunk size in characters (default 500)
        overlap: chunk overlap in characters (default 100)
    """

    def __init__(self,
                 directory: str,
                 chroma_server_type: str = "local",
                 persist_path: str = "chroma_db",
                 embedding_function: Optional[object] = None,
                 file_group_num: int = 80,
                 batch_num: int = 6,
                 chunksize: int = 500,
                 overlap: int = 100):
        self.directory = directory
        self.file_group_num = file_group_num
        self.batch_num = batch_num
        self.chunksize = chunksize
        self.overlap = overlap
        # ChromaManager takes the embedding model via its embed_model parameter
        self.chroma_db = ChromaManager(
            chroma_server_type=chroma_server_type,
            persist_path=persist_path,
            embed_model=embedding_function
        )
        self._setup_logging()
        # Verify the directory exists
        if not os.path.isdir(self.directory):
            raise ValueError(f"Directory does not exist: {self.directory}")

    def _setup_logging(self):
        """Configure logging."""
        log_dir = "logs"
        os.makedirs(log_dir, exist_ok=True)
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
            handlers=[
                logging.FileHandler(os.path.join(log_dir, "pdf_processor.log")),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

    def load_pdf_files(self) -> List[str]:
        """Scan the directory and return the paths of all PDF files.

        Raises:
            ValueError: if the directory contains no PDF files
        """
        pdf_files = [os.path.join(self.directory, f)
                     for f in os.listdir(self.directory)
                     if f.lower().endswith('.pdf')]
        if not pdf_files:
            raise ValueError(f"No PDF files found in directory: {self.directory}")
        self.logger.info(f"Found {len(pdf_files)} PDF files")
        return pdf_files

    def load_pdf_content(self, pdf_path: str) -> List[Document]:
        """Load the content of a single PDF file with PyMuPDF.

        Raises:
            RuntimeError: if the file fails to load
        """
        try:
            loader = PyMuPDFLoader(file_path=pdf_path)
            docs = loader.load()
            self.logger.debug(f"Loaded: {pdf_path} ({len(docs)} pages)")
            return docs
        except Exception as e:
            self.logger.error(f"Failed to load PDF {pdf_path}: {str(e)}")
            raise RuntimeError(f"Cannot load PDF file: {pdf_path}")

    def split_text(self, documents: List[Document]) -> List[Document]:
        """Split documents into chunks with the recursive character splitter."""
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunksize,
            chunk_overlap=self.overlap,
            length_function=len,
            add_start_index=True,
            separators=["\n\n", "\n", "。", "!", "?", ";", " ", ""]  # Chinese-friendly separators
        )
        try:
            docs = text_splitter.split_documents(documents)
            self.logger.info(f"Splitting done: {len(documents)} blocks → {len(docs)} chunks")
            return docs
        except Exception as e:
            self.logger.error(f"Text splitting failed: {str(e)}")
            raise RuntimeError("Error occurred while splitting text")

    def insert_docs_chromadb(self, docs: List[Document], batch_size: int = 6) -> None:
        """Insert documents into ChromaDB in batches, with a progress bar and throughput stats."""
        if not docs:
            self.logger.warning("Attempted to insert an empty document list")
            return
        self.logger.info(f"Inserting {len(docs)} documents into ChromaDB")
        start_time = time.time()
        total_docs_inserted = 0
        total_batches = (len(docs) + batch_size - 1) // batch_size
        try:
            with tqdm(total=total_batches, desc="Inserting", unit="batch") as pbar:
                for i in range(0, len(docs), batch_size):
                    batch = docs[i:i + batch_size]
                    self.chroma_db.add_documents(batch)
                    total_docs_inserted += len(batch)
                    # Throughput: documents processed per minute
                    elapsed_time = time.time() - start_time
                    tpm = (total_docs_inserted / elapsed_time) * 60 if elapsed_time > 0 else 0
                    pbar.set_postfix({"TPM": f"{tpm:.2f}", "docs": total_docs_inserted})
                    pbar.update(1)
            self.logger.info(f"Insertion complete! Total time: {time.time() - start_time:.2f}s")
        except Exception as e:
            self.logger.error(f"Document insertion failed: {str(e)}")
            raise RuntimeError(f"Document insertion failed: {str(e)}")

    def process_pdfs_group(self, pdf_files_group: List[str]) -> None:
        """Process one group of PDF files (load → split → store)."""
        try:
            # Stage 1: load all PDF content
            pdf_contents = []
            for pdf_path in pdf_files_group:
                pdf_contents.extend(self.load_pdf_content(pdf_path))
            # Stage 2: split the text
            if pdf_contents:
                docs = self.split_text(pdf_contents)
                # Stage 3: store in the vector database
                if docs:
                    self.insert_docs_chromadb(docs, self.batch_num)
        except Exception as e:
            self.logger.error(f"Failed to process PDF group: {str(e)}")
            # Optionally continue with the next group instead of aborting
            # raise

    def process_pdfs(self) -> None:
        """Main pipeline: scan the directory, then process all PDFs group by group."""
        self.logger.info("=== Starting PDF processing pipeline ===")
        start_time = time.time()
        try:
            pdf_files = self.load_pdf_files()
            for i in range(0, len(pdf_files), self.file_group_num):
                group = pdf_files[i:i + self.file_group_num]
                self.logger.info(
                    f"Processing file group {i // self.file_group_num + 1}"
                    f"/{(len(pdf_files) - 1) // self.file_group_num + 1}"
                )
                self.process_pdfs_group(group)
            self.logger.info(f"=== Done! Total time: {time.time() - start_time:.2f}s ===")
            print("PDF processing completed successfully!")
        except Exception as e:
            self.logger.error(f"PDF processing pipeline failed: {str(e)}")
            raise RuntimeError(f"PDF processing failed: {str(e)}")

Let's test this module.

Create unit/test_pdf_processor.py under tests:

import os
import pytest
from knowledge_base.builders.pdf_processor import PDFProcessor


@pytest.fixture
def test_resources(tmp_path):
    """Prepare test resources."""
    # Create the PDF test directory
    pdf_dir = tmp_path / "pdfs"
    pdf_dir.mkdir()
    # Copy a prepared sample PDF (or generate one dynamically)
    test_pdf = os.path.join(os.path.dirname(__file__), "test_files", "sample.pdf")
    target_pdf = pdf_dir / "test.pdf"
    with open(test_pdf, "rb") as src, open(target_pdf, "wb") as dst:
        dst.write(src.read())
    return {
        "pdf_dir": str(pdf_dir),
        "db_dir": str(tmp_path / "chroma_db"),
        "pdf_path": str(target_pdf)
    }


def test_pdf_processing(test_resources):
    processor = PDFProcessor(
        directory=test_resources["pdf_dir"],
        persist_path=test_resources["db_dir"]
    )
    processor.process_pdfs()
    # Verify the database was created
    assert os.path.exists(test_resources["db_dir"])
    assert any(os.listdir(test_resources["db_dir"]))

Result after running:

Testing started at 19:35 ...
Launching pytest with arguments D:\construction_QA_system\tests\unit\test_pdf_processor.py --no-header --no-summary -q in D:\construction_QA_system\tests\unit
============================= test session starts =============================
collecting ... collected 1 item

test_pdf_processor.py::test_pdf_processing PASSED                        [100%]
PDF processing completed successfully!
============================== 1 passed in 2.78s ==============================

Process finished with exit code 0

6. Vector Retrieval Module

Create retrieval/vector_retriever.py in knowledge_base:

from typing import List, Dict, Optional
from langchain_core.documents import Document
from ..storage.chroma_manager import ChromaManager


class VectorRetriever:
    def __init__(self,
                 chroma_server_type: str = "local",
                 persist_path: str = "chroma_db",
                 collection_name: str = "construction_docs",
                 embedding_function: Optional[object] = None,
                 top_k: int = 5):
        """Vector retriever.

        Args:
            top_k: number of most relevant results to return
        """
        # ChromaManager takes the embedding model via its embed_model parameter
        self.chroma_db = ChromaManager(
            chroma_server_type=chroma_server_type,
            persist_path=persist_path,
            collection_name=collection_name,
            embed_model=embedding_function
        )
        self.top_k = top_k

    @property
    def collection(self):
        """The underlying Chroma collection."""
        return self.chroma_db.client.get_collection(self.chroma_db.collection_name)

    def similarity_search(self,
                          query: str,
                          filter_conditions: Optional[Dict] = None) -> List[Document]:
        """Similarity search."""
        # Compute the query vector (assumes an embedding model was configured)
        query_embedding = self.chroma_db.embed_model.embed_query(query)
        # Run the query
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=self.top_k,
            where=filter_conditions
        )
        # Convert the raw results into Document objects
        docs = []
        for i in range(len(results['ids'][0])):
            docs.append(Document(
                page_content=results['documents'][0][i],
                metadata=results['metadatas'][0][i] or {}
            ))
        return docs

    def hybrid_search(self,
                      query: str,
                      keyword: Optional[str] = None,
                      filter_conditions: Optional[Dict] = None) -> List[Document]:
        """Hybrid retrieval (vector + keyword)."""
        # Run the vector search first
        vector_results = self.similarity_search(query, filter_conditions)
        # Then filter by keyword, if one is given
        if keyword:
            filtered = [doc for doc in vector_results
                        if keyword.lower() in doc.page_content.lower()]
            return filtered[:self.top_k]
        return vector_results

    def get_by_id(self, doc_id: str) -> Optional[Document]:
        """Fetch a document by ID."""
        result = self.collection.get(ids=[doc_id])
        if not result['documents']:
            return None
        return Document(
            page_content=result['documents'][0],
            metadata=result['metadatas'][0] or {}
        )

FastAPI Interface Wrapper

Create api/retrieval_api.py under the project root:

import logging
from typing import List, Optional
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from datetime import datetime

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Construction QA Retrieval API",
    description="Retrieval API for the construction knowledge base",
    version="1.0.0",
    openapi_tags=[{"name": "retrieval", "description": "Knowledge-base retrieval endpoints"}]
)

# Allow cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# --- Data models ---
class DocumentMetadata(BaseModel):
    """Document metadata model."""
    source: Optional[str] = Field(None, example="GB/T 50081-2019")
    page: Optional[int] = Field(None, example=12)
    timestamp: Optional[datetime] = Field(None, example="2023-01-01T00:00:00")


class DocumentResponse(BaseModel):
    """Retrieval result model."""
    id: str = Field(..., example="doc_123")
    content: str = Field(..., example="Concrete strength testing standard ...")
    metadata: DocumentMetadata
    score: float = Field(..., ge=0, le=1, example=0.85)


class QueryRequest(BaseModel):
    """Query request model."""
    query: str = Field(..., min_length=1, example="concrete strength standard")
    top_k: Optional[int] = Field(5, gt=0, le=20, example=3)
    keyword_filter: Optional[str] = Field(None, example="rebar")
    metadata_filter: Optional[dict] = Field(None, example={"source": "GB"})


class HealthCheckResponse(BaseModel):
    """Health-check response."""
    status: str = Field(..., example="OK")
    version: str = Field(..., example="1.0.0")


# --- Core logic ---
def initialize_retriever():
    """Initialize the retriever (a real project should use dependency injection)."""
    from knowledge_base.retrieval.vector_retriever import VectorRetriever
    from core.utils.embedding_utils import load_embedding_model
    try:
        return VectorRetriever(
            persist_path="data/vector_db",
            embedding_function=load_embedding_model(),
            top_k=10
        )
    except Exception as e:
        logger.error(f"Retriever initialization failed: {str(e)}")
        raise


retriever = initialize_retriever()


# --- API endpoints ---
@app.get("/health", response_model=HealthCheckResponse, tags=["system"])
async def health_check():
    """Service health check."""
    return {"status": "OK", "version": "1.0.0"}


@app.post("/search",
          response_model=List[DocumentResponse],
          tags=["retrieval"],
          summary="Document retrieval",
          responses={
              200: {"description": "Results returned successfully"},
              400: {"description": "Invalid request parameters"},
              500: {"description": "Internal server error"}
          })
async def search_documents(request: QueryRequest):
    """Run document retrieval. Supports:
    - pure vector retrieval
    - keyword-filtered retrieval
    - metadata-filtered retrieval
    """
    try:
        logger.info(f"Received search request: {request.dict()}")
        # Parameter validation
        if len(request.query) > 500:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Query text too long (max 500 characters)"
            )
        # Run the retrieval
        if request.keyword_filter or request.metadata_filter:
            docs = retriever.hybrid_search(
                query=request.query,
                keyword=request.keyword_filter,
                filter_conditions=request.metadata_filter
            )
        else:
            docs = retriever.similarity_search(
                query=request.query,
                filter_conditions=request.metadata_filter
            )
        # Format the results
        results = []
        for doc in docs[:request.top_k]:
            if not hasattr(doc, 'metadata'):
                doc.metadata = {}
            results.append({
                "id": str(hash(doc.page_content)),
                "content": doc.page_content,
                "metadata": doc.metadata,
                "score": doc.metadata.get("score", 0.0)
            })
        logger.info(f"Returning {len(results)} results")
        return results
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Search failed: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Retrieval service temporarily unavailable"
        )


@app.get("/document/{doc_id}",
         response_model=DocumentResponse,
         tags=["retrieval"],
         summary="Fetch a document by ID")
async def get_document(doc_id: str):
    """Fetch the full content of a document by its ID."""
    try:
        doc = retriever.get_by_id(doc_id)
        if not doc:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Document not found"
            )
        return {
            "id": doc_id,
            "content": doc.page_content,
            "metadata": doc.metadata or {},
            "score": 1.0
        }
    except HTTPException:
        # Re-raise the 404 instead of converting it to a 500
        raise
    except Exception as e:
        logger.error(f"Failed to fetch document: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Document retrieval failed"
        )


# --- Startup configuration ---
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        log_config={
            "version": 1,
            "disable_existing_loggers": False,
            "handlers": {
                "console": {
                    "class": "logging.StreamHandler",
                    "level": "INFO",
                    "formatter": "default"
                }
            },
            "formatters": {
                "default": {
                    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
                }
            },
            "root": {"handlers": ["console"], "level": "INFO"}
        }
    )
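Once the service is started with python api/retrieval_api.py, the /search endpoint can be exercised from the standard library alone. A sketch, assuming the service listens on localhost:8000 as configured above (the payload fields match the QueryRequest model):

```python
import json
import urllib.request

# Build a request body matching the QueryRequest model
payload = {"query": "concrete strength standard", "top_k": 3, "keyword_filter": "rebar"}
body = json.dumps(payload).encode("utf-8")

req = urllib.request.Request(
    "http://localhost:8000/search",
    data=body,
    headers={"Content-Type": "application/json"},
    method="POST",
)

# Uncomment once the service is running:
# with urllib.request.urlopen(req) as resp:
#     for doc in json.loads(resp.read()):
#         print(doc["score"], doc["content"][:60])
```

Because keyword_filter is set, the endpoint routes the request through hybrid_search rather than plain similarity_search.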

Test case: tests/unit/test_embedding_utils.py

import pytest
from unittest.mock import patch
from core.utils.embedding_utils import load_embedding_model


class TestEmbeddingUtils:
    @patch('langchain.embeddings.HuggingFaceEmbeddings')
    def test_load_huggingface(self, mock_embeddings):
        """Test loading a HuggingFace model."""
        model = load_embedding_model(model_type="huggingface")
        mock_embeddings.assert_called_once()

    @patch.dict('os.environ', {'OPENAI_API_KEY': 'test_key'})
    @patch('langchain.embeddings.OpenAIEmbeddings')
    def test_load_openai(self, mock_embeddings):
        """Test loading an OpenAI model."""
        model = load_embedding_model(model_type="openai")
        mock_embeddings.assert_called_once_with(
            model="text-embedding-3-small",
            deployment=None,
            openai_api_key="test_key"
        )

    def test_invalid_model_type(self):
        """Test an invalid model type."""
        with pytest.raises(ValueError):
            load_embedding_model(model_type="invalid_type")
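Note that core/utils/embedding_utils.py itself is never shown in this article. A minimal sketch consistent with the tests above might look like the following; the lazy in-function imports mirror the langchain.embeddings paths the tests patch, and the default HuggingFace model name is my assumption, not from the original:

```python
import os


def load_embedding_model(model_type: str = "huggingface", model_name: str = None):
    """Load an embedding model by type (sketch only; not from the original article)."""
    if model_type == "huggingface":
        # Imported lazily so an unsupported model_type fails fast without langchain
        from langchain.embeddings import HuggingFaceEmbeddings
        # Default model name is an assumption for illustration
        return HuggingFaceEmbeddings(model_name=model_name or "sentence-transformers/all-MiniLM-L6-v2")
    if model_type == "openai":
        from langchain.embeddings import OpenAIEmbeddings
        return OpenAIEmbeddings(
            model="text-embedding-3-small",
            deployment=None,
            openai_api_key=os.environ.get("OPENAI_API_KEY"),
        )
    raise ValueError(f"Unsupported model_type: {model_type}")
```

The unsupported-type branch raises before any langchain import, which is exactly what test_invalid_model_type exercises.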

 
