当前位置: 首页 > news >正文

关于模型记忆力的实现方式

问题描述

在对法律问题答疑功能进行调研的过程中我们发现,当用户决定使用AI来研究问题时大多数情况会进行两次以上的提问。但是模型生成回答的原理决定了上一次的对话与下一次的对话结果是弱相关的,因此决定使用保存历史记录的方式来实现记忆力功能。

方案:

将近10条历史对话(问题 + 回答)拼接成上下文。

附加在新的用户输入之前,一并发送给语言模型。

语言模型借助上下文内容,实现多轮连续对话的推理能力

具体实现:

数据结构设计

message:

class Message(BaseModel):role: str  content: str

role:用于向模型说明该content由谁发出。

conversation:

class Conversation(BaseModel):id: str  # 会话IDmessages: List[Message] = []

id:对不同的会话记录使用id来标识。

message:是一个message对象的列表,用来存储历史记录。

chatrequest:

class ChatRequest(BaseModel):conversation_id: str = None  message: str  max_turns: int = 10 

conversation_id:历史记录识别代码。

message:最新的一条消息。

Conversation_Manager:

class ConversationManager:def __init__(self):self.conversations: Dict[str, Conversation] = {}def create_conversation(self) -> Conversation:conv_id = str(uuid4())now = time.time()conversation = Conversation(id=conv_id,created_at=now,updated_at=now)self.conversations[conv_id] = conversationreturn conversationdef get_conversation(self, conv_id: str) -> Conversation:return self.conversations.get(conv_id)def add_message(self, conv_id: str, message: Message):if conv_id in self.conversations:self.conversations[conv_id].messages.append(message)self.conversations[conv_id].updated_at = time.time()def prune_messages(self, conv_id: str, max_turns: int):"""修剪过长的对话历史"""if conv_id in self.conversations:messages = self.conversations[conv_id].messagesif len(messages) > max_turns * 2:  # 用户和AI各算一轮self.conversations[conv_id].messages = messages[-max_turns*2:]conversation_manager = ConversationManager()

conversations:一个字典,conversation_id对应conversation。

方法解释:

create_conversation(self): 创建一个conversation对象并为其分配conversation_id\时间戳。然后将其存入conversations.

get_conversation(self,conv_id:str):通过conversation_id找到对应的conversation。

add_message(self,conv_id:str,message:Message): 通过conv_id找到对应的conversation后插入message。

prune_message(self,conv_id:str,max_turns:int):历史消息的剪切功能,保留最近的max_turn轮对话。

Prompt设计:

【法律咨询任务指令】
你是一位资深法律专家,根据以下历史对话回答问题:
- 历史对话:{{ history }}
- 对话要求:* 回答内容必须与法律相关* 你的回复内容长度必须小于等于200字请生成符合上述要求的法律回答:

接口设计:

@app.post("/chat")
async def chat(request: ChatRequest):# 获取或创建会话if request.conversation_id and conversation_manager.get_conversation(request.conversation_id):print("old \n")conversation = conversation_manager.get_conversation(request.conversation_id)else:print("new \n")conversation = conversation_manager.create_conversation()# 添加用户消息user_msg = Message(role="user", content=request.message)conversation_manager.add_message(conversation.id, user_msg)text=""for message in conversation.messages:text+=message.role+" : "+message.content+"\n"print(text)# 构建promptprompt = prompt_engine.render("chat",{"history": text,})# 调用模型生成回复inputs = tokenizer(prompt, return_tensors="pt").to(model.device)outputs = qa_model.generate(input_ids=inputs["input_ids"],max_new_tokens=200,max_length=2048,          num_beams=4,              # Beam Search平衡生成质量与速度early_stopping=True,      # 所有beam达到EOS时提前停止repetition_penalty=1.2,   # 抑制重复内容(法律文书需严谨)length_penalty=1.0,       # 中性长度惩罚(可根据需求调整)no_repeat_ngram_size=3,   # 避免3-gram重复do_sample=True,          # 禁用采样(确保确定性输出)temperature=0.3,          # 默认温度(配合do_sample=False无效))# 获取模型回复response = tokenizer.decode(outputs[0], skip_special_tokens=True)assistant_msg = response[len(prompt):]  # 提取新生成的部分# 添加AI回复到对话历史assistant_msg_obj = Message(role="assistant", content=assistant_msg.strip())conversation_manager.add_message(conversation.id, assistant_msg_obj)# 修剪过长的对话历史conversation_manager.prune_messages(conversation.id, request.max_turns)return {"conversation_id":conversation.id,"response": assistant_msg_obj.content}

当请求段传回的conversation_id为空时创建一个conversation否则根据conversation_id来找到对应的conversation。之后将前端用户发送的信息封装到message里对象里面,并使用add_message方法添加信息。最后将历史记录填充到prompt当中后传给挂载lora的模型生成结果,然后将结果又封装成message并将role设置为assistant。继续添加到conversation当中。检查对话历史长度。最后将对话id以及生成结果传回前端。

优化设计

当前记忆力功能的原理是使用一个类来存储conversations,因此需要考虑当用户基数过大时可能会出现爆内存的情况,另外当服务器关机时也会清空历史记录。因此解决方案可以是使用mysql设计字段conversation_id\role\content来存储历史记录;也可以是使用用户的浏览器缓存来存储历史记录。以上两种方案都可以,但是考虑到敏捷开发以及请求端的调试需求,本博客的方案也不失为一种快速实现的方法。

另外maxturns的参数设计也十分重要。当maxturns过大时,可能会导致模型输入过长响应时间慢或者根本无法处理的问题。当maxturns过小时,又会出现过早丢失用户之前的问题,导致记忆力丧失。

效果展示

http://www.xdnf.cn/news/658873.html

相关文章:

  • Linux GPIO子系统深度解析:从历史演进到实战应用
  • 使用 Pfam 和 InterProScan 进行蛋白质家族和功能域的分析
  • 第一章:MLOps/LLMOps 导论:原则、生命周期与挑战
  • 激光开卷落料线:技术革新与产业应用综述
  • PCCW Global 与银河航天在港成功完成低轨卫星测试
  • 紫光同创FPGA实现视频采集转USB2.0输出,基于CY7C68013芯片,提供PDS工程源码和技术支持和QT上位机
  • DC-DC升压
  • 【Qt】Debug版本正常运行,Release版本运行卡死
  • FreeRTOS 事件标志组详解:原理、用法与实战技巧
  • 网页模板素材网站 web前端网页制作模板
  • 如何清除浏览器启动hao点360
  • 【多智能体系统开发框架AutoGen解析与实践】
  • 初学ADC
  • 【四】频率域滤波(下)【830数字图像处理】
  • 华为OD机试真题——通信系统策略调度(用户调度问题)(2025B卷:100分)Java/python/JavaScript/C/C++/GO最佳实现
  • 算力服务器和GPU服务器之间的联系
  • C++中使用类的继承机制来定义和实现基类与派生类
  • 初始化硬盘时,选MBR还是GUID?—「小白教程」
  • Linux系统中为Qt项目封装一个udp客户端类
  • 在麒麟系统(Kylin OS)上安装`geckodriver`
  • 跳板问题(贪心算法+细节思考)
  • 中国工程咨询协会新型基础设施专业委员会成立
  • Open vSwitch笔记20250526
  • 基于python合成100X100的透明背景图片和图标
  • 十大排序算法
  • 单例模式,饿汉式,懒汉式,在java和spring中的体现
  • 从数据页角度理解B+树查询
  • Netty学习专栏(五):Netty高性能揭秘(Reactor模式与零拷贝的深度实践)
  • 华为OD机试真题——单词接龙(首字母接龙)(2025A卷:100分)Java/python/JavaScript/C/C++/GO最佳实现
  • 股指期货移仓换月技巧是什么?