【大模型记忆实战Demo】基于SpringAIAlibaba通过内存和Redis两种方式实现多轮记忆对话
文章目录
- 多轮对话记忆管理——基于Memory的对话记忆
- 基于内存存储历史对话
- 基于Redis存储历史对话
多轮对话记忆管理——基于Memory的对话记忆
Spring AI Alibaba共实现了三种方式:
- 基于内存的方式
- 基于jdbc(数据库)的方式
- 基于redis的方式
下文主要演示基于内存和redis的方式
基于内存存储历史对话
- 代码
首先定义大模型的角色,一个旅游规划师
设置增强拦截器
接着接口传入prompt和chatId
设定好唯一标识符和记忆轮数
private final ChatClient chatClient;public ChatMemoryController(ChatModel chatModel) {this.chatClient = ChatClient.builder(chatModel).defaultSystem("你是一个旅游规划师,请根据用户的需求提供旅游规划建议。").defaultAdvisors(new MessageChatMemoryAdvisor(new InMemoryChatMemory())).build();}
/*** 获取内存中的聊天内容* 根据提供的prompt和chatId,从内存中获取相关的聊天内容,并设置响应的字符编码为UTF-8。** @param prompt 用于获取聊天内容的提示信息* @param chatId 聊天的唯一标识符,用于区分不同的聊天会话* @param response HTTP响应对象,用于设置响应的字符编码* @return 返回包含聊天内容的Flux<String>对象*/@GetMapping("/in-memory")public Flux<String> memory(@RequestParam("prompt") String prompt,@RequestParam("chatId") String chatId,HttpServletResponse response) {response.setCharacterEncoding("UTF-8");return chatClient.prompt(prompt).advisors(a -> a.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId).param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 100)).stream().content();}
- 调用结果
第一次:
我提问,想去杭州玩
- 第二轮
我提问:那有哪些好玩的地方
可以看到,第二次根据我第一次“杭州”的关键词进行了推荐,拥有了记忆
基于内存的方法存在一个缺点:如果机器重启了,记忆就消失了,因此可以采用持久化到Redis的方式
基于Redis存储历史对话
- 代码
private final ChatClient chatClient;public ChatMemoryController(ChatModel chatModel) {this.chatClient = ChatClient.builder(chatModel).defaultSystem("你是一个旅游规划师,请根据用户的需求提供旅游规划建议。").defaultAdvisors(new MessageChatMemoryAdvisor(new RedisChatMemory("127.0.0.1",6379,null))).build();}/*** 从Redis中获取聊天内容* 根据提供的prompt和chatId,从Redis中检索聊天内容,并以Flux<String>的形式返回** @param prompt 聊天内容的提示或查询关键字* @param chatId 聊天的唯一标识符,用于从Redis中检索特定的聊天内容* @param response HttpServletResponse对象,用于设置响应的字符编码为UTF-8* @return Flux<String> 包含聊天内容的反应式流*/@GetMapping("/redis")public Flux<String> redis(@RequestParam("prompt") String prompt,@RequestParam("chatId") String chatId,HttpServletResponse response) {response.setCharacterEncoding("UTF-8");return chatClient.prompt(prompt).advisors(a -> a.param(CHAT_MEMORY_CONVERSATION_ID_KEY, chatId).param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10)).stream().content();}
其中的RedisChatMemory:
/**** 基于Redis的聊天记忆实现。* 该类实现了ChatMemory接口,提供了将聊天消息存储到Redis中的功能。** @author Fox*/
public class RedisChatMemory implements ChatMemory, AutoCloseable {private static final Logger logger = LoggerFactory.getLogger(RedisChatMemory.class);private static final String DEFAULT_KEY_PREFIX = "chat:";private static final String DEFAULT_HOST = "127.0.0.1";private static final int DEFAULT_PORT = 6379;private static final String DEFAULT_PASSWORD = null;private final JedisPool jedisPool;private final ObjectMapper objectMapper;public RedisChatMemory() {this(DEFAULT_HOST, DEFAULT_PORT, DEFAULT_PASSWORD);}public RedisChatMemory(String host, int port, String password) {JedisPoolConfig poolConfig = new JedisPoolConfig();this.jedisPool = new JedisPool(poolConfig, host, port, 2000, password);this.objectMapper = new ObjectMapper();logger.info("Connected to Redis at {}:{}", host, port);}@Overridepublic void add(String conversationId, List<Message> messages) {String key = DEFAULT_KEY_PREFIX + conversationId;AtomicLong timestamp = new AtomicLong(System.currentTimeMillis());try (Jedis jedis = jedisPool.getResource()) {// 使用pipeline批量操作提升性能var pipeline = jedis.pipelined();messages.forEach(message ->pipeline.hset(key, String.valueOf(timestamp.getAndIncrement()), message.toString()));pipeline.sync();}logger.info("Added messages to conversationId: {}", conversationId);}@Overridepublic List<Message> get(String conversationId, int lastN) {String key = DEFAULT_KEY_PREFIX + conversationId;try (Jedis jedis = jedisPool.getResource()) {Map<String, String> allMessages = jedis.hgetAll(key);if (allMessages.isEmpty()) {return List.of();}return allMessages.entrySet().stream().sorted((e1, e2) ->Long.compare(Long.parseLong(e2.getKey()), Long.parseLong(e1.getKey()))).limit(lastN).map(entry -> new UserMessage(entry.getValue())).collect(Collectors.toList());}}@Overridepublic void clear(String conversationId) {String key = DEFAULT_KEY_PREFIX + conversationId;try (Jedis jedis = jedisPool.getResource()) {jedis.del(key);}logger.info("Cleared messages for conversationId: {}", conversationId);}@Overridepublic void close() {try (Jedis jedis = jedisPool.getResource()) {if (jedis != null) {jedis.close();logger.info("Redis connection closed.");}if (jedisPool != null) {jedisPool.close();logger.info("Jedis pool closed.");}}}public void clearOverLimit(String conversationId, int maxLimit, int deleteSize) {try {String key = DEFAULT_KEY_PREFIX + conversationId;try (Jedis jedis = jedisPool.getResource()) {List<String> all = jedis.lrange(key, 0, -1);if (all.size() >= maxLimit) {all = all.stream().skip(Math.max(0, deleteSize)).toList();}this.clear(conversationId);for (String message : all) {jedis.rpush(key, message);}}}catch (Exception e) {logger.error("Error clearing messages from Redis chat memory", e);throw new RuntimeException(e);}}}
- 第一次调用
提问:我想去三亚
可见成功将提问和回答都写入redis
- 第二次调用
成功读取记忆,并将新的问答结果写入
至此,我们成功完成了基于内存和Redis两种方式,实现大模型的多轮记忆对话!