当前位置: 首页 > news >正文

SpringAI 1.0.0 正式版——利用Redis存储会话(ChatMemory)

官方文档:Chat Memory :: Spring AI Reference

1. 引言

SpringAI 1.0.0 改动了很多地方,本文根据官方的InMemoryChatMemoryRepository实现了自定义的RedisChatMemoryRepository,并使用MessageWindowChatMemory创建ChatMemory

2. 实现

2.1. 添加依赖

```xml
<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-starter-model-openai</artifactId>
    <version>1.0.0</version>
</dependency>
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
```

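注意:SpringAI 1.0.0 的 Maven 依赖有所改变,starter 的 artifactId 换了命名方式,例如 OpenAI 的 starter 由旧版的 spring-ai-openai-spring-boot-starter 改为 spring-ai-starter-model-openai。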

2.2. 配置文件

```yaml
server:
  port: 8080
spring:
  ai:
    openai:
      api-key: xxx     # 填自己的api-key
      base-url: https://api.deepseek.com
      chat:
        options:
          model: deepseek-chat
          temperature: 0.7
  data:
    redis:
      host: localhost
      port: 6379
      password: 123456
```

正确配置redis连接即可

api-key可以填deepseek的(需要购买,1块钱能用挺久)

2.3. RedisChatMemoryRepository

RedisChatMemoryRepository用于存储会话数据

这里参考InMemoryChatMemoryRepository与【SpringAI 1.0.0】 ChatMemory 转换为 Redis 存储_springai如何将数据保存到redis-CSDN博客

package com.njust.repository;import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.messages.*;
import org.springframework.ai.content.Media;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.util.MimeType;
import java.io.IOException;
import java.net.URL;
import java.util.*;
import java.util.stream.Collectors;public class RedisChatMemoryRepository implements ChatMemoryRepository {private final StringRedisTemplate stringRedisTemplate;  // 用于操作 Redisprivate final ObjectMapper objectMapper;    // 用于序列化和反序列化private final String PREFIX ;       // 存储对话的 Redis Key 前缀private final String CONVERSATION_IDS_SET;  // 存储对话ID的 Redis Keypublic RedisChatMemoryRepository(StringRedisTemplate stringRedisTemplate, ObjectMapper objectMapper) {this(stringRedisTemplate, objectMapper, "chat:conversation:", "chat:all_conversation_ids");}public RedisChatMemoryRepository(StringRedisTemplate stringRedisTemplate, ObjectMapper objectMapper, String PREFIX) {this(stringRedisTemplate, objectMapper, PREFIX, "chat:all_conversation_ids");}public RedisChatMemoryRepository(StringRedisTemplate stringRedisTemplate, ObjectMapper objectMapper, String PREFIX, String CONVERSATION_IDS_SET) {this.stringRedisTemplate = stringRedisTemplate;this.objectMapper = objectMapper;this.PREFIX = PREFIX;this.CONVERSATION_IDS_SET = CONVERSATION_IDS_SET;}// 获取所有 conversationId(KEYS 命令匹配 chat:*)@Overridepublic List<String> findConversationIds() {// 使用ZSet存储对话ID(更高效)// 获取对话ID集合(按时间倒序排序,越晚创建的对话ID排在前面)Set<String> conversationIds = stringRedisTemplate.opsForZSet().reverseRange(CONVERSATION_IDS_SET, 0, -1);if (conversationIds == null || conversationIds.isEmpty()) {return List.of();}return new ArrayList<>(conversationIds);}// 根据 conversationId 获取 Message 列表@Overridepublic List<Message> findByConversationId(String conversationId) {// 参数验证if (conversationId == null || conversationId.isEmpty()) {throw new IllegalArgumentException("conversationId cannot be null or empty");}List<String> list = stringRedisTemplate.opsForList().range(PREFIX + conversationId, 0, -1);if (list == null || list.isEmpty()) {return List.of();}return list.stream().map(json -> {try {// return objectMapper.convertValue(json, Message.class);   // 直接反序列化Message会报错return deserializeMessage(json);    // 手动反序列化} catch (IOException e) {throw 
new RuntimeException(e);}}).collect(Collectors.toList());}// 保存整个 Message 列表到指定 conversationId@Overridepublic void saveAll(String conversationId, List<Message> messages) {// 参数验证if (conversationId == null || conversationId.isEmpty()) {throw new IllegalArgumentException("conversationId cannot be null or empty");}// 先清除原有的 conversation 数据stringRedisTemplate.delete(PREFIX + conversationId);if (messages == null || messages.isEmpty()) {return;}List<String> list = messages.stream().map(message -> {try {return objectMapper.writeValueAsString(message);} catch (JsonProcessingException e) {throw new RuntimeException("Failed to serialize Message", e);}}).collect(Collectors.toList());stringRedisTemplate.opsForList().rightPushAll(PREFIX + conversationId, list);// 更新对话ID集合stringRedisTemplate.opsForZSet().add(CONVERSATION_IDS_SET, conversationId, System.currentTimeMillis());}// 删除指定 conversationId 的数据@Overridepublic void deleteByConversationId(String conversationId) {if (conversationId == null || conversationId.isEmpty()) {throw new IllegalArgumentException("conversationId cannot be null or empty");}stringRedisTemplate.delete(PREFIX + conversationId);stringRedisTemplate.opsForZSet().remove(CONVERSATION_IDS_SET, conversationId);}// 手动反序列化 Messagepublic Message deserializeMessage(String json) throws IOException {// 解析 JSON 字符串为 JsonNodeJsonNode jsonNode = objectMapper.readTree(json);// 获取 messageType 字段值if (!jsonNode.has("messageType")) {throw new IllegalArgumentException("Missing or invalid messageType field");}String messageType = jsonNode.get("messageType").asText();// 获取 text 字段值String text = jsonNode.has("text") ? 
jsonNode.get("text").asText() : "";// 获取 metadata 字段值Map<String, Object> metadata = getMetadata(jsonNode);// 获取 media 字段值List<Media> mediaList = getMediaList(jsonNode);return switch (MessageType.valueOf(messageType)) {case SYSTEM -> new SystemMessage(text);case USER -> UserMessage.builder().text(text).media(mediaList).metadata(metadata).build();case ASSISTANT -> {List<AssistantMessage.ToolCall> toolCalls = getToolCalls(jsonNode);yield new AssistantMessage(text, metadata, toolCalls, mediaList);}default -> throw new IllegalArgumentException("Unknown message type: " + messageType);};}private Media deserializeMedia(ObjectMapper mapper, JsonNode mediaNode) throws IOException {Media.Builder builder = Media.builder();// Handle MIME typeif (mediaNode.has("mimeType")) {JsonNode mimeNode = mediaNode.get("mimeType");String type = mimeNode.get("type").asText();String subtype = mimeNode.get("subtype").asText();builder.mimeType(new MimeType(type, subtype));}// Handle data - could be either URL string or byte arrayif (mediaNode.has("data")) {String data = mediaNode.get("data").asText();if (data.startsWith("http://") || data.startsWith("https://")) {builder.data(new URL(data));} else {// Assume it's base64 encoded binary databyte[] bytes = Base64.getDecoder().decode(data);builder.data(bytes);}}// Handle dataAsByteArray if present (overrides data if both exist)if (mediaNode.has("dataAsByteArray")) {byte[] bytes = Base64.getDecoder().decode(mediaNode.get("dataAsByteArray").asText());builder.data(bytes);}// Handle optional fieldsif (mediaNode.has("id")) {builder.id(mediaNode.get("id").asText());}if (mediaNode.has("name")) {builder.name(mediaNode.get("name").asText());}return builder.build();}private Map<String, Object> getMetadata(JsonNode jsonNode) {if (jsonNode.has("metadata")) {return objectMapper.convertValue(jsonNode.get("metadata"), new TypeReference<>() {});}return new HashMap<>();}private List<Media> getMediaList(JsonNode jsonNode) throws IOException {List<Media> mediaList = 
new ArrayList<>();if (jsonNode.has("media")) {for (JsonNode mediaNode : jsonNode.get("media")) {mediaList.add(deserializeMedia(objectMapper, mediaNode));}}return mediaList;}private List<AssistantMessage.ToolCall> getToolCalls(JsonNode jsonNode) {if (jsonNode.has("toolCalls")) {return objectMapper.convertValue(jsonNode.get("toolCalls"), new TypeReference<>() {});}return Collections.emptyList();}
}

主要的部分都写上注释了,应该比较好理解

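需要注意的是,反序列化 Message 需要手动进行:Message 是一个接口,而存入 Redis 的 JSON 中没有可供 Jackson 选择实现类的类型信息,直接用 objectMapper 转换成 Message.class 会报错,所以要先读出 messageType 字段,再据此手动构造对应的具体消息对象。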

2.4. 注册Bean

package com.njust.config;import com.fasterxml.jackson.databind.ObjectMapper;
import com.njust.repository.ChatHistoryRepository;
import com.njust.repository.RedisChatHistoryRepository;
import com.njust.repository.RedisChatMemoryRepository;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.core.StringRedisTemplate;@Configuration
public class CommonConfiguration {@Beanpublic ChatMemoryRepository chatMemoryRepository(StringRedisTemplate stringRedisTemplate) {// 默认情况下,如果尚未配置其他存储库,则 Spring AI 会自动配置ChatMemoryRepository类型的 beanInMemoryChatMemoryRepository可以直接在应用程序中使用。// 这里手动创建内存聊天记忆存储库return new RedisChatMemoryRepository(stringRedisTemplate, new ObjectMapper());}@Beanpublic ChatMemory chatMemory(ChatMemoryRepository chatMemoryRepository) {// 注册聊天上下文记忆机制return MessageWindowChatMemory.builder().chatMemoryRepository(chatMemoryRepository).maxMessages(20)   // 聊天记忆条数.build();}@Bean// 通过OpenAI平台注入deepseek模型public ChatClient deepseekChatClient(OpenAiChatModel openAiChatModel, ChatMemory chatMemory) {return ChatClient.builder(openAiChatModel).defaultSystem("你是南京理工大学计算机科学与工程学院的一名研究生,你的名字叫小兰").defaultAdvisors(new SimpleLoggerAdvisor(),  // 配置日志AdvisorMessageChatMemoryAdvisor.builder(chatMemory).build()    // 绑定上下文记忆).build();}
}

这里用MessageWindowChatMemory创建ChatMemory,用于限制上下文记忆条数(maxMessages(20),即每个会话最多保留最近的20条消息)。

2.5. Controller

package com.njust.controller;import com.njust.repository.ChatHistoryRepository;
import lombok.RequiredArgsConstructor;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.content.Media;
import org.springframework.util.MimeType;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;
import reactor.core.publisher.Flux;import java.util.List;
import java.util.Objects;import static org.springframework.ai.chat.memory.ChatMemory.CONVERSATION_ID;// @RequiredArgsConstructor 的作用是:为所有 final 字段或带有 @NotNull 注解的字段自动生成构造函数,实现简洁、安全的依赖注入
@RequiredArgsConstructor
@RestController
@RequestMapping("/ai")
public class ChatController {

    /** Chat client bean, injected through the Lombok-generated constructor. */
    private final ChatClient deepseekChatClient;

    /**
     * Sends {@code prompt} to the model and streams the reply back as {@code text/html}.
     *
     * @param prompt the user's message
     * @param chatId conversation id, handed to the advisors under {@code CONVERSATION_ID} so the
     *               chat-memory advisor works on this conversation's history
     * @return the reply content, emitted chunk by chunk
     */
    @RequestMapping(value = "/chat", produces = "text/html;charset=utf-8")
    public Flux<String> chat(@RequestParam("prompt") String prompt,
                             @RequestParam("chatId") String chatId) {
        var request = deepseekChatClient.prompt()
                .user(prompt)
                .advisors(advisorSpec -> advisorSpec.param(CONVERSATION_ID, chatId));
        return request.stream().content();
    }
}

http://www.xdnf.cn/news/897283.html

相关文章:

  • Kafka 入门指南与一键部署
  • SpringCloud学习笔记-3
  • Linux命令基础(2)
  • 软件功能测试目的是啥?如何通过测试用例确保产品达标?
  • <2>-MySQL库的操作
  • Python 字典(dict)的高级用法与技巧
  • 跨平台游戏引擎 Axmol-2.6.1 发布
  • [论文阅读] 人工智能 | 利用负信号蒸馏:用REDI框架提升LLM推理能力
  • 使用vsftpd搭建FTP服务器(TLS/SSL显式加密)
  • 大模型与 NLP、Transformer 架构
  • vue3子组件获取并修改父组件的值
  • TTT讲师认证题目学习记录
  • C++算法训练营 Day10 栈与队列(1)
  • Java学习——正则表达式
  • PHP语言核心技术全景解析
  • 双碳时代,能源调度的难题正从“发电侧”转向“企业侧”
  • MySQL体系架构解析(二):MySQL目录与启动配置全解析
  • React从基础入门到高级实战:React 实战项目 - 项目三:实时聊天应用
  • Linux容器篇、第二章_01Ubuntu22 环境下 KubeSphere 容器平台高可用搭建全流程
  • 悲观锁和乐观锁
  • 数据库SQLite基础
  • 《完全背包》题集
  • 天机学堂(学习计划和进度)
  • TDengine 开发指南——无模式写入
  • vue-20(Vuex 状态管理的最佳实践)
  • 如何配置nginx解决前端跨域请求问题
  • Nuxt.js 中的路由配置详解
  • (转)什么是DockerCompose?它有什么作用?
  • Ubuntu 基于sdl 音频学习的基础代码
  • 市面上哪款AI开源软件做ppt最好?