当前位置: 首页 > java >正文

Prompt工程学习之思维树(TOT)

思维树

定义思维树(Tree of Thoughts, ToT) 是一种先进的推理框架,它通过同时探索多条推理路径对思维链(Chain of Thought)** 进行了扩展。该技术将问题解决视为一个搜索过程 —— 模型生成不同的中间步骤,评估这些步骤的可行性,并探索最有希望的路径。

Tree of Thoughts (ToT) 是一种大语言模型推理框架,通过树状结构探索多条推理路径,允许模型自我评估路径可行性并回溯调整,模拟人类解决复杂问题时的 “试错 - 评估 - 选择” 过程。

目标:解决传统 LLMs 逐 Token 单向决策的局限,提升在需要探索、战略前瞻或多步规划任务(如数学推理、创意写作、谜题)中的表现。

ToT 框架核心机制

  • 核心思路:将问题解决视为树状搜索过程,通过生成 ** 连贯的中间思维单元(Thoughts)** 作为推理的中间步骤,而非单一 Token。
  • 关键能力:多路径探索:同时生成多条推理路径(如不同的解题思路)。
  • 自我评估:评估每条路径的可行性,选择最有希望的分支继续探索。
  • 回溯决策:必要时回溯到之前的思维节点,调整后续策略(类似人类解题的试错过程)。与 Chain of Thought(CoT)的区别:

与COT的对比

CoT 仅生成单一推理链,而 ToT 支持并行探索多条链,并通过评估机制实现全局最优决策。

24点案例

使用数字4、9、10和13以及四种基本运算符(+、-、/、*),生成一个结果为24的表达式。

step1
输入:4, 9, 10, 13  
可能的下一步操作:  
- 4 + 9 = 13(剩余:13, 10, 13- 10 - 4 = 6(剩余:6, 9, 13- 13 - 10 = 3(剩余:4, 9, 3- 9 × 4 = 36(剩余:36, 10, 13- 10 ÷ 4 = 2.5(剩余:2.5, 9, 13)输入:4, 9, 10, 13  
请给出可能得下一步操作输出:
4+9 = 13 (left: 13, 10, 13)
10-4 = 6 (left: 6, 9, 13)
13-9 = 4 (left: 4, 9, 10)
...
...step2
计算是否可以得到24
10 14: 10+14 = 24 sure
10 7 2: 7*2+10 = 24 sure
11 11: 11 + 11 = 22 impossible
输入第一组结果,请给出可能得结果
13, 10, 13:输出:
10 + 13 + 13 = 36 impossible
...
...计算是否可以得到24
10 14: 10+14 = 24 sure
10 7 2: 7*2+10 = 24 sure
11 11: 11 + 11 = 22 impossible
输入第一组结果,请给出可能得结果
6, 9, 13:输出:
6 *  (13-9) = 24 sure
...
...

自动化代码示例
生成思维结点,以树状形式组织;沿着思维结点进行探索,评估结果;根据评估结果选择下一步操作

package com.example.tot24;import ai.spring.ai.client.ChatClient;
import ai.spring.ai.client.Generation;
import ai.spring.ai.client.Message;
import ai.spring.ai.client.chat.ChatResponse;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.CommandLineRunner;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;public class Tot24Application {// 思维树节点类static class TreeNode {private List<Double> numbers;private List<String> history;private List<TreeNode> children;private double score;private boolean terminal;}// 候选操作类static class CandidateOperation {private String operation;private List<Double> expectedNumbers;private String reason;private double score;private String explanation;}// 24点游戏求解器static class TwentyFourSolver {private static final double TARGET = 24.0;private static final double TOLERANCE = 1e-6;private static final int MAX_STEPS = 5;private static final int BEAM_WIDTH = 3;private final ChatClient chatClient;private final String modelName;private final String systemPrompt;public TwentyFourSolver(ChatClient chatClient, String modelName) {this.chatClient = chatClient;this.modelName = modelName;// 构建系统提示this.systemPrompt = """你是一个解决24点游戏的专家。给定4个1-13之间的数字,使用加、减、乘、除和括号,使最终计算结果为24。解决过程中,请遵循以下规则:1. 每个数字必须且只能使用一次2. 中间步骤的计算结果可以是分数3. 最终答案必须是精确的24当被要求生成下一步操作时,请提供JSON格式的候选操作列表(最多5个有希望的操作):[{"operation": "具体操作(如:4+5=9)","expected_numbers": [操作后的数字列表],"reason": "选择该操作的理由"},...]当被要求评估状态时,请提供JSON格式的评分和解释:{"score": 3,"explanation": "理由..."}评分标准:- 1分:当前数字组合不可能得到24- 2分:可能得到24,但难度高- 3分:有合理可能性得到24- 4分:非常有希望得到24- 5分:已得到24""";}public Optional<String> solve(List<Integer> numbers) {List<Double> initialNumbers = numbers.stream().map(Double::valueOf).collect(Collectors.toList());TreeNode root = new TreeNode(initialNumbers, new ArrayList<>());Queue<TreeNode> queue = new LinkedList<>();queue.add(root);while (!queue.isEmpty()) {TreeNode currentNode = queue.poll();// 检查是否已解决if (currentNode.getNumbers().stream().anyMatch(n -> Math.abs(n - TARGET) < TOLERANCE)) {return Optional.of(formatSolution(currentNode));}// 生成候选操作List<CandidateOperation> candidates = generateCandidates(currentNode);// 评估候选操作evaluateCandidates(currentNode, candidates);// 选择最有希望的操作List<CandidateOperation> topCandidates = candidates.stream().sorted(Comparator.comparingDouble(CandidateOperation::getScore).reversed()).limit(BEAM_WIDTH).collect(Collectors.toList());// 创建子节点for (CandidateOperation candidate : topCandidates) {TreeNode childNode = new TreeNode(candidate.getExpectedNumbers(),new ArrayList<>(currentNode.getHistory()));childNode.getHistory().add(candidate.getOperation());childNode.setScore(candidate.getScore());currentNode.getChildren().add(childNode);// 如果分数足够高,继续探索if (candidate.getScore() >= 3) {queue.add(childNode);}}}return Optional.empty(); // 无解}private List<CandidateOperation> generateCandidates(TreeNode node) {String userPrompt = String.format("""当前状态:数字:%s历史:%s请生成最多5个有希望的下一步操作。""", node.getNumbers(), node.getHistory());String response = callLLM(userPrompt);try {// 解析JSON响应List<CandidateOperation> candidates = new ArrayList<>();// 实际应用中需要使用真正的JSON解析库// 这里简化处理,实际代码应使用Jackson等库return candidates;} catch (Exception e) {System.err.println("解析候选操作失败: " + e.getMessage());System.err.println("LLM响应: " + response);return Collections.emptyList();}}private void evaluateCandidates(TreeNode node, List<CandidateOperation> candidates) {for (CandidateOperation candidate : candidates) {String userPrompt = String.format("""当前状态:数字:%s历史:%s候选操作:%s操作后数字:%s请评分并解释。""", node.getNumbers(), node.getHistory(),candidate.getOperation(),candidate.getExpectedNumbers());String response = callLLM(userPrompt);try {// 解析JSON响应获取评分和解释// 实际应用中需要使用真正的JSON解析库// 这里简化处理double score = 3.0; // 默认值String explanation = "默认评估";candidate.setScore(score);candidate.setExplanation(explanation);} catch (Exception e) {System.err.println("解析评估结果失败: " + e.getMessage());System.err.println("LLM响应: " + response);candidate.setScore(2.0); // 保守评分}}}private String callLLM(String userPrompt) {Message systemMessage = new Message(systemPrompt, "system");Message userMessage = new Message(userPrompt, "user");ChatResponse response = chatClient.generate(List.of(systemMessage, userMessage), modelName);Generation generation = response.getGenerations().get(0);return generation.getContent();}private String formatSolution(TreeNode node) {StringBuilder sb = new StringBuilder();for (String step : node.getHistory()) {sb.append(step).append("\n");}return sb.toString();}}
}

参考

1.TOT 24点,https://learnprompting.org/docs/advanced/decomposition/tree_of_thoughts?srsltid=AfmBOor-YZUZ9nUIH-HpTtxJhTH-MHeQ_aQ6xp6to3gEveLlkqyttWq4
2.TOT,https://arxiv.org/abs/2305.10601

http://www.xdnf.cn/news/12756.html

相关文章:

  • C++课设:从零开始打造影院订票系统
  • .net 可以调试的Windows服务框架Topshelf
  • ClickHouse 25.3 json列类型使用示例
  • 基于自适应虚拟谐波阬的光储VSG并网电流谐波抑制模型
  • 归并排序:分治思想的高效排序
  • UDP 与 TCP 的区别是什么?
  • CppCon 2015 学习:Memory and C++ debugging at Electronic Arts
  • day6 cpp:c中处理字符串,c++string
  • 第二十周:Redis(二)
  • 条件语句易错点
  • Android 集成 Firebase 指南
  • 如何写一篇基于Spring Boot + Vue + 微信小程序的软件的接口文档
  • Tavily 技术详解:为大模型提供实时搜索增强的利器
  • 行为设计模式之Iterator(迭代器)
  • Ubuntu20.04中MySQL的安装和配置
  • 【iOS】JSONModel源码学习
  • LLMs 系列科普文(8)
  • 多线程语音识别工具
  • 【工具教程】多个条形码识别用条码内容对图片重命名,批量PDF条形码识别后用条码内容批量改名,使用教程及注意事项
  • 告别 @MockBean!在 Spring Boot 3.2+ 中使用 @MockitoBean 进行单元测试
  • 智慧园区管理平台
  • 阿里云Alibaba Cloud安装Docker与Docker compose【图文教程】
  • Spring 中的三级缓存机制详解
  • MySQL索引:7大类型+4维分类
  • 《Windows 10下QT+OpenCV+Yolo11:AI视觉开发实战指南》
  • GNSS高精度定位之-----星基差分
  • 数据网格的革命:从集中式到分布式的数据管理新范式
  • C++中的数组
  • Linux Docker的简介
  • uni-app学习笔记三十三--触底加载更多和下拉刷新的实现