spring-ai-alibaba官方 Playground 示例之联网搜索代码解析
1、联网搜索controller
/** Licensed to the Apache Software Foundation (ASF) under one or more* contributor license agreements. See the NOTICE file distributed with* this work for additional information regarding copyright ownership.* The ASF licenses this file to You under the Apache License, Version 2.0* (the "License"); you may not use this file except in compliance with* the License. You may obtain a copy of the License at** http://www.apache.org/licenses/LICENSE-2.0** Unless required by applicable law or agreed to in writing, software* distributed under the License is distributed on an "AS IS" BASIS,* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.* See the License for the specific language governing permissions and* limitations under the License.*/package com.alibaba.cloud.ai.application.controller;import com.alibaba.cloud.ai.application.service.SAAWebSearchService;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.servlet.http.HttpServletResponse;
import reactor.core.publisher.Flux;import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;/*** @author yuluo* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>** The deepseek-r1 model is used by default, which works better.*/@RestController
@Tag(name = "Web Search APIs")
@RequestMapping("/api/v1")
public class SAAWebSearchController {//联网搜索服务类private final SAAWebSearchService webSearch;public SAAWebSearchController(SAAWebSearchService webSearch) {this.webSearch = webSearch;}@PostMapping("/search")public Flux<String> search(HttpServletResponse response,@Validated @RequestBody String prompt) {response.setCharacterEncoding("UTF-8");return webSearch.chat(prompt);}}
2、联网搜索服务类
/** Licensed to the Apache Software Foundation (ASF) under one or more* contributor license agreements. See the NOTICE file distributed with* this work for additional information regarding copyright ownership.* The ASF licenses this file to You under the Apache License, Version 2.0* (the "License"); you may not use this file except in compliance with* the License. You may obtain a copy of the License at** http://www.apache.org/licenses/LICENSE-2.0** Unless required by applicable law or agreed to in writing, software* distributed under the License is distributed on an "AS IS" BASIS,* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.* See the License for the specific language governing permissions and* limitations under the License.*/package com.alibaba.cloud.ai.application.service;import com.alibaba.cloud.ai.application.advisor.ReasoningContentAdvisor;
import com.alibaba.cloud.ai.application.modulerag.WebSearchRetriever;
import com.alibaba.cloud.ai.application.modulerag.core.IQSSearchEngine;
import com.alibaba.cloud.ai.application.modulerag.data.DataClean;
import com.alibaba.cloud.ai.application.modulerag.join.ConcatenationDocumentJoiner;
import com.alibaba.cloud.ai.application.modulerag.prompt.CustomContextQueryAugmenter;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.rag.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.rag.preretrieval.query.expansion.QueryExpander;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;import java.util.Map;
import java.util.logging.Logger;/*** @author yuluo* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>*/@Service
public class SAAWebSearchService {private final DataClean dataCleaner;private final ChatClient chatClient;private final QueryExpander queryExpander;private final QueryTransformer queryTransformer;private final WebSearchRetriever webSearchRetriever;private final SimpleLoggerAdvisor simpleLoggerAdvisor;private final PromptTemplate queryArgumentPromptTemplate;private final ReasoningContentAdvisor reasoningContentAdvisor;// It works better here with DeepSeek-R1private static final String DEFAULT_WEB_SEARCH_MODEL = "deepseek-r1";private static final Logger log = Logger.getLogger(SAAWebSearchService.class.getName());public SAAWebSearchService(DataClean dataCleaner,QueryExpander queryExpander,IQSSearchEngine searchEngine,QueryTransformer queryTransformer,SimpleLoggerAdvisor simpleLoggerAdvisor,@Qualifier("dashscopeChatModel") ChatModel chatModel,@Qualifier("queryArgumentPromptTemplate") PromptTemplate queryArgumentPromptTemplate) {//联网搜索服务类//数据清洗this.dataCleaner = dataCleaner;//查询重写this.queryTransformer = queryTransformer;//多查询扩展this.queryExpander = queryExpander;//联网搜索提示模板this.queryArgumentPromptTemplate = queryArgumentPromptTemplate;// reasoning content for DeepSeek-r1 is integrated into the outputthis.reasoningContentAdvisor = new ReasoningContentAdvisor(1);// Build chatClientthis.chatClient = ChatClient.builder(chatModel).defaultOptions(DashScopeChatOptions.builder().withModel(DEFAULT_WEB_SEARCH_MODEL)// stream 模式下是否开启增量输出.withIncrementalOutput(true).build()).build();this.simpleLoggerAdvisor = simpleLoggerAdvisor;this.webSearchRetriever = WebSearchRetriever.builder().searchEngine(searchEngine).dataCleaner(dataCleaner).maxResults(2).build();}public Flux<String> chat(String prompt) {Map<Integer, String> webLink = dataCleaner.getWebLink();return chatClient.prompt().advisors(createRetrievalAugmentationAdvisor(),reasoningContentAdvisor,simpleLoggerAdvisor).user(prompt).stream().content();// .transform(contentStream -> embedLinks(contentStream, webLink));}// todo 效果不好,这里只是一种思路// stream 中 [[ 可能是一个 chunk 输出,而 ]] 在另一个 stream 中。在遇到第一个 [[ 时,短暂阻塞,到 ]] 出现时,开始替换执行后续逻辑private Flux<String> embedLinks(Flux<String> contentStream, Map<Integer, String> webLink) {// State for managing incomplete tagsStringBuilder buffer = new StringBuilder();return contentStream.flatMap(chunk -> {StringBuilder output = new StringBuilder(); // Output for this chunkint i = 0;while (i < chunk.length()) {char c = chunk.charAt(i);if (c == '[' && i + 1 < chunk.length() && chunk.charAt(i + 1) == '[') {// Start of [[...]]buffer.append("[[");i += 2; // Skip [[} else if (buffer.length() > 0 && c == ']' && i + 1 < chunk.length() && chunk.charAt(i + 1) == ']') {// End of [[...]]buffer.append("]]");String tag = buffer.toString(); // Complete tagoutput.append(resolveLink(tag, webLink)); // Resolve and appendbuffer.setLength(0); // Clear bufferi += 2; // Skip ]]} else if (buffer.length() > 0) {// Inside [[...]]buffer.append(c);i++;} else {// Normal textoutput.append(c);i++;}}// If buffer still contains data, leave it for the next chunkreturn Flux.just(output.toString());}).concatWith(Flux.defer(() -> {// If there's any leftover in the buffer, append it as-isif (buffer.length() > 0) {return Flux.just(buffer.toString());}return Flux.empty();}));}private String resolveLink(String tag, Map<Integer, String> webLink) {// Extract the number inside [[...]] and resolve the URLif (tag.startsWith("[[") && tag.endsWith("]]")) {String keyStr = tag.substring(2, tag.length() - 2); // Remove [[ and ]]try {int key = Integer.parseInt(keyStr);if (webLink.containsKey(key)) {return "[" + key + "](" + webLink.get(key) + ")";}} catch (NumberFormatException e) {// Not a valid number, return the original tag}}return tag; // Return original tag if no match}private RetrievalAugmentationAdvisor createRetrievalAugmentationAdvisor() {
// 使用RetrievalAugmentationAdvisor增强查询效果return RetrievalAugmentationAdvisor.builder()
// 配置文档检索器.documentRetriever(webSearchRetriever)
// 查询重写.queryTransformers(queryTransformer).queryAugmenter(new CustomContextQueryAugmenter(queryArgumentPromptTemplate,null,true))//多查询扩展.queryExpander(queryExpander).documentJoiner(new ConcatenationDocumentJoiner()).build();}}
3、查询重写和联网搜索提示模板
/** Licensed to the Apache Software Foundation (ASF) under one or more* contributor license agreements. See the NOTICE file distributed with* this work for additional information regarding copyright ownership.* The ASF licenses this file to You under the Apache License, Version 2.0* (the "License"); you may not use this file except in compliance with* the License. You may obtain a copy of the License at** http://www.apache.org/licenses/LICENSE-2.0** Unless required by applicable law or agreed to in writing, software* distributed under the License is distributed on an "AS IS" BASIS,* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.* See the License for the specific language governing permissions and* limitations under the License.*/package com.alibaba.cloud.ai.application.modulerag.prompt;import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;/*** Prompt:* 1. https://zhuanlan.zhihu.com/p/23929522431* 2. https://cloud.tencent.com/developer/article/2509465** @author yuluo* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>*/@Configuration
public class PromptTemplateConfig {@Beanpublic PromptTemplate transformerPromptTemplate() {//查询重写 提示模板return new PromptTemplate("""Given a user query, rewrite the user question to provide better results when querying {target}.You should follow these rules:1. Remove any irrelevant information and make sure the query is concise and specific;2. The output must be consistent with the language of the user's query;3. Ensure better understanding and answers from the perspective of large models.Original query:{query}Query after rewrite:""");}@Beanpublic PromptTemplate queryArgumentPromptTemplate() {
// 联网搜索提示模板return new PromptTemplate("""You'll get a set of document contexts that are relevant to the issue.Each document begins with a reference number, such as [[x]], where x is a number that can be repeated.Documents that are not referenced will be marked as [[null]].Use context and refer to it at the end of each sentence, if applicable.The context information is as follows:---------------------{context}---------------------Generate structured responses to user questions given contextual information and without prior knowledge.When you answer user questions, follow these rules:1. If the answer is not in context, say you don't know;2. Don't provide any information that is not relevant to the question, and don't output any duplicate content;3. Avoid using "context-based..." or "The provided information..." said;4. Your answers must be correct, accurate, and written in an expertly unbiased and professional tone;5. The appropriate text structure in the answer is determined according to the characteristics of the content, please include subheadings in the output to improve readability;6. When generating a response, provide a clear conclusion or main idea first, without a title;7. Make sure each section has a clear subtitle so that users can better understand and refer to your output;8. If the information is complex or contains multiple sections, make sure each section has an appropriate heading to create a hierarchical structure;9. Please refer to the sentence or section with the reference number at the end in [[x]] format;10. If a sentence or section comes from more than one context, list all applicable references, e.g. [[x]][[y]];11. Your output answers must be in beautiful and rigorous markdown format.12. Because your output is in markdown format, please include the link in the reference document in the form of a hyperlink when referencing the context, so that users can click to view it;13. If a reference is marked as [[null]], it does not have to be cited;14. Except for Code. Aside from the specific name and citation, your answer must be written in the same language as the question.User Issue:{query}Your answer:""");}
}
4、多查询扩展
/** Licensed to the Apache Software Foundation (ASF) under one or more* contributor license agreements. See the NOTICE file distributed with* this work for additional information regarding copyright ownership.* The ASF licenses this file to You under the Apache License, Version 2.0* (the "License"); you may not use this file except in compliance with* the License. You may obtain a copy of the License at** http://www.apache.org/licenses/LICENSE-2.0** Unless required by applicable law or agreed to in writing, software* distributed under the License is distributed on an "AS IS" BASIS,* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.* See the License for the specific language governing permissions and* limitations under the License.*/package com.alibaba.cloud.ai.application.modulerag.preretrieval.query.expansion;import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.preretrieval.query.expansion.QueryExpander;
import org.springframework.ai.rag.util.PromptAssert;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;/*** User Prompt Query Expander** @author yuluo* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>*/public class MultiQueryExpander implements QueryExpander {private static final Logger logger = LoggerFactory.getLogger(MultiQueryExpander.class);// 多查询扩展 提示模板private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("""You are an expert in information retrieval and search optimization.Generate {number} different versions of a given query.Each variation should cover a different perspective or aspect of the topic while maintaining the core intent ofthe original query. The goal is to broaden your search and improve your chances of finding relevant information.Don't interpret the selection or add additional text.Query variants are provided, separated by line breaks.Original query: {query}Query variants:""");private static final Boolean DEFAULT_INCLUDE_ORIGINAL = true;private static final Integer DEFAULT_NUMBER_OF_QUERIES = 3;private final ChatClient chatClient;private final PromptTemplate promptTemplate;private final boolean includeOriginal;private final int numberOfQueries;public MultiQueryExpander(ChatClient.Builder chatClientBuilder,@Nullable PromptTemplate promptTemplate,@Nullable Boolean includeOriginal,@Nullable Integer numberOfQueries) {Assert.notNull(chatClientBuilder, "ChatClient.Builder must not be null");this.chatClient = chatClientBuilder.build();this.promptTemplate = promptTemplate == null ? DEFAULT_PROMPT_TEMPLATE : promptTemplate;this.includeOriginal = includeOriginal == null ? DEFAULT_INCLUDE_ORIGINAL : includeOriginal;this.numberOfQueries = numberOfQueries == null ? DEFAULT_NUMBER_OF_QUERIES : numberOfQueries;PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "number", "query");}@NotNull@Overridepublic List<Query> expand(@Nullable Query query) {Assert.notNull(query, "Query must not be null");logger.debug("Generating {} queries for query: {}", this.numberOfQueries, query.text());String resp = this.chatClient.prompt().user(user -> user.text(this.promptTemplate.getTemplate()).param("number", this.numberOfQueries).param("query", query.text())).call().content();logger.debug("MultiQueryExpander#expand() Response from chat client: {}", resp);if (Objects.isNull(resp)) {logger.warn("No response from chat client for query: {}. is return.", query.text());return List.of(query);}List<String> queryVariants = Arrays.stream(resp.split("\n")).filter(StringUtils::hasText).toList();if (CollectionUtils.isEmpty(queryVariants) || this.numberOfQueries != queryVariants.size()) {logger.warn("Query expansion result dose not contain the requested {} variants for query: {}. is return.",this.numberOfQueries, query.text());return List.of(query);}List<Query> queries = queryVariants.stream().filter(StringUtils::hasText).map(queryText -> query.mutate().text(queryText).build()).collect(Collectors.toList());if (this.includeOriginal) {logger.debug("Including original query in the expanded queries for query: {}", query.text());queries.add(0, query);}logger.debug("Rewrite queries: {}", queries);return queries;}public static Builder builder() {return new Builder();}public static final class Builder {private ChatClient.Builder chatClientBuilder;private PromptTemplate promptTemplate;private Boolean includeOriginal;private Integer numberOfQueries;private Builder() {}public Builder chatClientBuilder(ChatClient.Builder chatClientBuilder) {this.chatClientBuilder = chatClientBuilder;return this;}public Builder promptTemplate(PromptTemplate promptTemplate) {this.promptTemplate = promptTemplate;return this;}public Builder includeOriginal(Boolean includeOriginal) {this.includeOriginal = includeOriginal;return this;}public Builder numberOfQueries(Integer numberOfQueries) {this.numberOfQueries = numberOfQueries;return this;}public MultiQueryExpander build() {return new MultiQueryExpander(this.chatClientBuilder, this.promptTemplate, this.includeOriginal, this.numberOfQueries);}}}
5、实例化查询重写和多查询扩展
package com.alibaba.cloud.ai.application.config;//import com.alibaba.cloud.ai.application.rag.postretrieval.DashScopeDocumentRanker;
import com.alibaba.cloud.ai.application.modulerag.preretrieval.query.expansion.MultiQueryExpander;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.rag.preretrieval.query.expansion.QueryExpander;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;/*** @author yuluo* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>*/@Configuration
public class WeSearchConfiguration {
//
// @Bean
// public DashScopeDocumentRanker dashScopeDocumentRanker(
// RerankModel rerankModel
// ) {
// return new DashScopeDocumentRanker(rerankModel);
// }@Beanpublic QueryTransformer queryTransformer(@Qualifier("dashscopeChatModel") ChatModel chatModel,@Qualifier("transformerPromptTemplate") PromptTemplate transformerPromptTemplate) {//实例化查询重写ChatClient chatClient = ChatClient.builder(chatModel).defaultOptions(DashScopeChatOptions.builder().withModel("qwen-plus").build()).build();
// 创建查询重写转换器return RewriteQueryTransformer.builder().chatClientBuilder(chatClient.mutate()).promptTemplate(transformerPromptTemplate).targetSearchSystem("Web Search").build();}@Beanpublic QueryExpander queryExpander(@Qualifier("dashscopeChatModel") ChatModel chatModel) {//实例化查询扩展ChatClient chatClient = ChatClient.builder(chatModel).defaultOptions(DashScopeChatOptions.builder().withModel("qwen-plus").build()).build();
//多查询扩展是提高RAG系统检索效果的关键技术。在实际应用中,
// 用户的查询往往是简短且不完整的,这可能导致检索结果不够准确或完整。
// Spring AI提供了强大的多查询扩展机制,能够自动生成多个相关的查询变体,
// 从而提高检索的准确性和召回率return MultiQueryExpander.builder().chatClientBuilder(chatClient.mutate()).numberOfQueries(2).build();}}