关于模型记忆力的实现方式
问题描述
在对法律问题答疑功能进行调研的过程中我们发现,当用户决定使用AI来研究问题时大多数情况会进行两次以上的提问。但是模型生成回答的原理决定了上一次的对话与下一次的对话结果是弱相关的,因此决定使用保存历史记录的方式来实现记忆力功能。
方案:
将近10条历史对话(问题 + 回答)拼接成上下文。
附加在新的用户输入之前,一并发送给语言模型。
语言模型借助上下文内容,实现多轮连续对话的推理能力
具体实现:
数据结构设计
message:
class Message(BaseModel):role: str content: str
role:用于向模型说明该content由谁发出。
conversation:
class Conversation(BaseModel):id: str # 会话IDmessages: List[Message] = []
id:对不同的会话记录使用id来标识。
message:是一个message对象的列表,用来存储历史记录。
chatrequest:
class ChatRequest(BaseModel):conversation_id: str = None message: str max_turns: int = 10
conversation_id:历史记录识别代码。
message:最新的一条消息。
Conversation_Manager:
class ConversationManager:def __init__(self):self.conversations: Dict[str, Conversation] = {}def create_conversation(self) -> Conversation:conv_id = str(uuid4())now = time.time()conversation = Conversation(id=conv_id,created_at=now,updated_at=now)self.conversations[conv_id] = conversationreturn conversationdef get_conversation(self, conv_id: str) -> Conversation:return self.conversations.get(conv_id)def add_message(self, conv_id: str, message: Message):if conv_id in self.conversations:self.conversations[conv_id].messages.append(message)self.conversations[conv_id].updated_at = time.time()def prune_messages(self, conv_id: str, max_turns: int):"""修剪过长的对话历史"""if conv_id in self.conversations:messages = self.conversations[conv_id].messagesif len(messages) > max_turns * 2: # 用户和AI各算一轮self.conversations[conv_id].messages = messages[-max_turns*2:]conversation_manager = ConversationManager()
conversations:一个字典,conversation_id对应conversation。
方法解释:
create_conversation(self): 创建一个conversation对象并为其分配conversation_id\时间戳。然后将其存入conversations.
get_conversation(self,conv_id:str):通过conversation_id找到对应的conversation。
add_message(self,conv_id:str,message:Message): 通过conv_id找到对应的conversation后插入message。
prune_message(self,conv_id:str,max_turns:int):历史消息的剪切功能,保留最近的max_turn轮对话。
Prompt设计:
【法律咨询任务指令】
你是一位资深法律专家,根据以下历史对话回答问题:
- 历史对话:{{ history }}
- 对话要求:* 回答内容必须与法律相关* 你的回复内容长度必须小于等于200字请生成符合上述要求的法律回答:
接口设计:
@app.post("/chat")
async def chat(request: ChatRequest):# 获取或创建会话if request.conversation_id and conversation_manager.get_conversation(request.conversation_id):print("old \n")conversation = conversation_manager.get_conversation(request.conversation_id)else:print("new \n")conversation = conversation_manager.create_conversation()# 添加用户消息user_msg = Message(role="user", content=request.message)conversation_manager.add_message(conversation.id, user_msg)text=""for message in conversation.messages:text+=message.role+" : "+message.content+"\n"print(text)# 构建promptprompt = prompt_engine.render("chat",{"history": text,})# 调用模型生成回复inputs = tokenizer(prompt, return_tensors="pt").to(model.device)outputs = qa_model.generate(input_ids=inputs["input_ids"],max_new_tokens=200,max_length=2048, num_beams=4, # Beam Search平衡生成质量与速度early_stopping=True, # 所有beam达到EOS时提前停止repetition_penalty=1.2, # 抑制重复内容(法律文书需严谨)length_penalty=1.0, # 中性长度惩罚(可根据需求调整)no_repeat_ngram_size=3, # 避免3-gram重复do_sample=True, # 禁用采样(确保确定性输出)temperature=0.3, # 默认温度(配合do_sample=False无效))# 获取模型回复response = tokenizer.decode(outputs[0], skip_special_tokens=True)assistant_msg = response[len(prompt):] # 提取新生成的部分# 添加AI回复到对话历史assistant_msg_obj = Message(role="assistant", content=assistant_msg.strip())conversation_manager.add_message(conversation.id, assistant_msg_obj)# 修剪过长的对话历史conversation_manager.prune_messages(conversation.id, request.max_turns)return {"conversation_id":conversation.id,"response": assistant_msg_obj.content}
当请求段传回的conversation_id为空时创建一个conversation否则根据conversation_id来找到对应的conversation。之后将前端用户发送的信息封装到message里对象里面,并使用add_message方法添加信息。最后将历史记录填充到prompt当中后传给挂载lora的模型生成结果,然后将结果又封装成message并将role设置为assistant。继续添加到conversation当中。检查对话历史长度。最后将对话id以及生成结果传回前端。
优化设计
当前记忆力功能的原理是使用一个类来存储conversations,因此需要考虑当用户基数过大时可能会出现爆内存的情况,另外当服务器关机时也会清空历史记录。因此解决方案可以是使用mysql设计字段conversation_id\role\content来存储历史记录;也可以是使用用户的浏览器缓存来存储历史记录。以上两种方案都可以,但是考虑到敏捷开发以及请求端的调试需求,本博客的方案也不失为一种快速实现的方法。
另外maxturns的参数设计也十分重要。当maxturns过大时,可能会导致模型输入过长响应时间慢或者根本无法处理的问题。当maxturns过小时,又会出现过早丢失用户之前的问题,导致记忆力丧失。