当前位置: 首页 > news >正文

LLM笔记(十一)常见解码/搜索算法

1. 束搜索 (Beam Search)
hugging face transformer实现源码
beam search详细说明blog

  • 原理: 束搜索是一种启发式图搜索算法,它在生成序列的每一步保留固定数量(束宽 k)的最可能的候选序列(称为“束”或“假设”)。它试图通过探索更广的搜索空间来找到比贪心搜索更高概率的整体序列。

  • 工作流程:

    1. 初始化: 从起始符或基于输入提示开始,生成第一批token,形成初始的候选序列。通常,如果直接从提示开始,第一步会选择概率最高的 k 个token作为 k 个初始束。
    2. 扩展 (Expansion): 在第 t 步,对于当前保留的 k 个束(每个束是一个部分生成的序列),将每个束扩展一位。即,对每个束,预测所有可能的下一个token及其概率。这将产生 k * V 个潜在的新序列(V是词汇表大小)。
    3. 评估 (Evaluation): 计算这 k * V 个扩展序列的累积概率(通常是log概率之和,因为连乘小概率数会导致下溢)。为了处理不同长度序列的偏好问题,有时会使用长度归一化 (log_prob / length^alpha) 或其他更复杂的评分函数。
    4. 剪枝 (Pruning): 从所有 k * V 个扩展序列中,选择得分最高的 k 个序列,作为第 t+1 步的新束。其他序列被丢弃。
    5. 分组与合并(可选但常见): 如果多个扩展序列最终以相同的token结尾,但来自不同的父束,它们在下一步扩展时可以被视为不同的路径。如果遇到 <eos>(序列结束符),该束完成,将其放入“已完成束”的列表中。当“已完成束”的数量达到 k 个,或者所有活动束都已完成,或者达到最大长度时,搜索可以提前终止。
    6. 终止: 当所有保留的 k 个束都生成了结束符 (<eos>),或者都达到了预设的最大长度时,搜索结束。
    7. 最终选择: 从所有“已完成束”中,选择得分最高(通常经过长度归一化后)的那个序列作为最终输出。
  • 特点:

    • 平衡探索与利用: 在贪心搜索(只利用)和穷举搜索(完全探索,不可行)之间取得平衡。
    • 更高质量的输出: 通常比贪心搜索产生更连贯、更符合上下文、整体概率更高的序列。
    • 束宽 k 的影响:
      • 较小的 k (如 k=1 即为贪心搜索): 计算快,但质量可能不高。
      • 较大的 k 探索更广,可能产生更好的结果,但计算和内存成本显著增加。k 过大时,边际效益递减。
    • 内存需求大: 需要同时维护 k 个候选序列及其完整的KV缓存历史。
    • 不保证全局最优: 仍然是启发式,可能会因为过早剪枝而错失真正的最优序列。
    • 可能偏好短序列: 如果不进行长度归一化,累积(负log)概率会倾向于选择较短的序列,因为每多一个token,概率(小于1)会进一步减小(负log概率会增大)。
  • 适用场景: 对生成文本质量要求较高的任务,如机器翻译、文本摘要、语音识别的文本转录等。

  • PagedAttention对其优化: 正如我们之前讨论的,PagedAttention通过KV缓存共享和写时复制 (CoW) 显著降低了束搜索的内存占用,使得可以使用更大的束宽或处理更长的序列。

  • 代码说明 (概念性,高度简化):

    def beam_search_decoder(prompt_tokens, model, tokenizer, beam_width, max_length):# beams: 列表,每个元素是 (log_probability, [token_ids])# 初始化第一个时间步的束initial_logits = model.forward(prompt_tokens)[:, -1, :]initial_log_probs = torch.log_softmax(initial_logits, dim=-1)top_k_log_probs, top_k_indices = torch.topk(initial_log_probs, beam_width, dim=-1)beams = []for i in range(beam_width):token_id = top_k_indices[0, i].item()log_prob = top_k_log_probs[0, i].item()# 初始序列包含提示和第一个生成的tokencurrent_sequence = prompt_tokens.squeeze(0).tolist() + [token_id]beams.append((log_prob, current_sequence))completed_beams = [] # 存储已遇到 EOS 的完整束for step in range(max_length - 1): # -1 因为第一步已经走了if not beams: # 如果所有束都已完成breakall_candidates = [] # 存储当前时间步所有可能的扩展束# 扩展当前的所有活动束new_beams = []for current_log_prob, current_seq_ids in beams:if current_seq_ids[-1] == tokenizer.eos_token_id:# 如果束已结束,将其移至已完成列表,不再扩展# 为了保持beam_width,可以考虑是否立即补充或在最后统一处理completed_beams.append((current_log_prob, current_seq_ids))continue # 不再扩展此束# 将当前序列ID列表转换为tensor输入模型current_input_tensor = torch.tensor([current_seq_ids], device=model.device)# 获取下一个token的logitswith torch.no_grad(): # 推理时不需要梯度next_token_logits = model.forward(current_input_tensor)[:, -1, :]next_token_log_probs = torch.log_softmax(next_token_logits, dim=-1)# 为当前束选择beam_width个最佳扩展# 实践中,是从所有 k*V 个候选者中选 top-k,这里简化为每个束独立选 top-k# 然后再从所有束的 top-k 扩展中选最终的 top-ktop_k_next_log_probs, top_k_next_indices = torch.topk(next_token_log_probs, beam_width, dim=-1)for i in range(beam_width):next_token_id = top_k_next_indices[0, i].item()next_log_prob = top_k_next_log_probs[0, i].item()new_seq_ids = current_seq_ids + [next_token_id]new_total_log_prob = current_log_prob + next_log_prob # Log概率是相加的all_candidates.append((new_total_log_prob, new_seq_ids))if not all_candidates: # 如果没有可扩展的了(例如所有活动束都遇到了EOS)beams = [] # 清空活动束continue# 从所有候选扩展中选择新的 top-k 束# 根据log概率排序,选择最高的 beam_width 个# 实际实现会更复杂,需要考虑已完成束和活动束的总数ordered_candidates = sorted(all_candidates, key=lambda x: x[0], reverse=True)beams = ordered_candidates[:beam_width] # 更新活动束# 如果已完成的束数量达到beam_width,可以考虑提前停止(可选策略)# if len(completed_beams) >= beam_width:#     break# 将剩余的活动束(可能未完成但达到max_length)也加入完成列表completed_beams.extend(beams)if not completed_beams:return "Error: No completed beams."# 选择最佳的已完成束 (可能需要长度归一化)# 这里简化为选择log概率最高的# length_penalty_alpha = 0.6 # 示例长度惩罚因子# best_beam = max(completed_beams, key=lambda x: x[0] / (len(x[1])**length_penalty_alpha) )best_beam = max(completed_beams, key=lambda x: x[0])return tokenizer.decode(best_beam[1][len(prompt_tokens.squeeze(0).tolist()):]) # 去掉原始prompt部分
    

    注:这是一个非常简化且概念性的束搜索实现,旨在说明基本流程。实际的生产级束搜索实现要复杂得多,需要高效地管理束的状态、处理KV缓存(这正是vLLM PagedAttention的用武之地)、EOS处理、长度惩罚、重复惩罚、最小长度约束等。Hugging Face Transformers库提供了经过优化的束搜索实现。

2. 贪心搜索 (Greedy Search)

  • 原理: 在每一步都选择当前模型预测概率最高的那个token作为输出,然后将这个token加入到已生成的序列中,作为下一步的条件。
  • 特点:
    • 简单快速: 计算成本最低。
    • 确定性: 相同输入和模型,输出总是一致。
    • 缺点: 容易陷入局部最优,可能导致重复、不连贯或缺乏多样性的文本。
  • 适用场景: 对速度要求极高,文本质量要求相对较低的场景,或作为基线对比。
  • 代码说明 (概念性):
    generated_tokens = []
    current_input = prompt_tokens
    for _ in range(max_length):logits = model.forward(current_input) # 获取模型输出的logitsnext_token_logits = logits[:, -1, :]  # 取最后一个时间步的logitsnext_token_id = torch.argmax(next_token_logits, dim=-1) # 选择概率最高的tokenif next_token_id == EOS_TOKEN_ID: # End Of Sequence tokenbreakgenerated_tokens.append(next_token_id.item())current_input = torch.cat([current_input, next_token_id.unsqueeze(0)], dim=-1) # 将新token加入输入return tokenizer.decode(generated_tokens)
    

3. 随机采样 (Random Sampling / Stochastic Sampling)

  • 原理: 在每一步,根据模型输出的整个词汇表的概率分布进行随机抽样来选择下一个token。
  • 特点:
    • 多样性: 引入随机性,可产生不同输出。
    • 可控性(通过温度 Temperature):
      • 温度 (Temperature, T): 调整概率分布的平滑度。
        • logits_scaled = logits / T
        • probabilities = softmax(logits_scaled)
        • 高温 (T > 1): 分布更平缓,更随机、有创意,可能不连贯。
        • 低温 (0 < T < 1): 分布更尖锐,更接近贪心,保守可预测。
        • T = 1: 原始概率。
    • 缺点: 纯随机(尤其高温)可能产生不相关或无意义文本。
  • 适用场景: 需要多样化、创意文本的场景(故事、对话、诗歌)。
  • 代码说明 (概念性,含温度):
    generated_tokens = []
    current_input = prompt_tokens
    temperature = 0.7 # 示例温度值
    for _ in range(max_length):logits = model.forward(current_input)next_token_logits = logits[:, -1, :]# 应用温度scaled_logits = next_token_logits / temperatureprobabilities = torch.softmax(scaled_logits, dim=-1)# 根据概率分布进行采样next_token_id = torch.multinomial(probabilities, num_samples=1) if next_token_id.item() == EOS_TOKEN_ID:breakgenerated_tokens.append(next_token_id.item())current_input = torch.cat([current_input, next_token_id], dim=-1) # multinomial返回(batch,1)return tokenizer.decode(generated_tokens)
    

4. Top-K 采样 (Top-K Sampling)

  • 原理: 在每一步,先选出概率最高的 K 个token,然后仅在这 K 个token中根据它们的(重新归一化的)概率分布进行随机采样。
  • 特点:
    • 平衡多样性与连贯性: 限制采样池,减少低概率、不相关token的出现。
    • K 的影响: 小K保守,大K多样。
    • 缺点: 固定K值不一定适应所有概率分布形状。
  • 适用场景: 大多数需要随机性的文本生成任务,作为纯随机采样的改进。
  • 代码说明 (概念性):
    generated_tokens = []
    current_input = prompt_tokens
    k = 50 # 示例K值
    for _ in range(max_length):logits = model.forward(current_input)next_token_logits = logits[:, -1, :]# 获取Top-K的logits和indicestop_k_logits, top_k_indices = torch.topk(next_token_logits, k, dim=-1)# 在Top-K的logits上应用softmax进行归一化top_k_probabilities = torch.softmax(top_k_logits, dim=-1)# 从Top-K中采样一个token的索引 (相对于top_k_indices的索引)sampled_index_in_top_k = torch.multinomial(top_k_probabilities, num_samples=1)# 获取实际的token IDnext_token_id = top_k_indices.gather(-1, sampled_index_in_top_k)if next_token_id.item() == EOS_TOKEN_ID:breakgenerated_tokens.append(next_token_id.item())current_input = torch.cat([current_input, next_token_id], dim=-1)return tokenizer.decode(generated_tokens)
    

5. Top-P (Nucleus) 采样 (Top-P Sampling / Nucleus Sampling)

  • 原理: 在每一步,选择一个累积概率总和达到阈值 P 的最小token集合(“核心”),然后仅在这个核心集合中进行随机采样。
  • 特点:
    • 动态调整采样范围: 核心集合大小随概率分布形状而变。
    • 更好的质量控制: 通常比Top-K能产生更高质量且多样化的文本。
    • P 的影响: P接近1多样性高,P小则保守。
  • 适用场景: 高质量文本生成的首选随机策略之一。
  • 代码说明 (概念性):
    generated_tokens = []
    current_input = prompt_tokens
    p = 0.9 # 示例P值
    for _ in range(max_length):logits = model.forward(current_input)next_token_logits = logits[:, -1, :]probabilities = torch.softmax(next_token_logits, dim=-1)# 对概率进行排序并计算累积概率sorted_probabilities, sorted_indices = torch.sort(probabilities, descending=True, dim=-1)cumulative_probabilities = torch.cumsum(sorted_probabilities, dim=-1)# 找到累积概率首次超过P的阈值点# 核心集合是累积概率 <= P 的token,以及第一个 > P 的token# (实际实现中,通常是选择累积概率 >=P 的最小集合,或者保留那些概率 > (1-P)/V 的token)# 简化版:选择累积概率首次超过P的那些tokennucleus_indices_mask = cumulative_probabilities < p # 至少包含一个概率最高的词if not torch.any(nucleus_indices_mask):nucleus_indices_mask[..., 0] = True # 若所有词累加都不够p,至少选第一个else:# 将第一个超过p的也包含进来last_false_idx = (~nucleus_indices_mask).nonzero(as_tuple=True)[-1]if last_false_idx.numel() > 0 and last_false_idx.item() < nucleus_indices_mask.size(-1) -1:# 找到第一个false,将其前一个true(即最后一个小于p的)的下一个位置(即第一个大于p的)也包含true_indices = nucleus_indices_mask.nonzero(as_tuple=True)if true_indices[-1].numel() > 0:first_gt_p_idx = true_indices[-1][-1] + 1if first_gt_p_idx < nucleus_indices_mask.size(-1):nucleus_indices_mask[..., first_gt_p_idx] = True# 从原始probabilities中过滤出核心集合的概率masked_probabilities = probabilities.clone()# 将不在核心集合的token概率置为0# 这里需要将 nucleus_indices_mask 应用回原始的、未排序的概率上# 为了简化,我们假设可以直接在排序后的概率上操作并映射回原始索引# 实际操作:# 1. 找到满足条件的 sorted_indices# 2. 在这些 sorted_indices 对应的原始 probabilities 上进行采样indices_to_remove = sorted_indices[..., nucleus_indices_mask.sum().item():] # 假设mask是连续true然后falsemasked_probabilities.scatter_(-1, indices_to_remove, 0)if masked_probabilities.sum() == 0: # 如果核心集合为空或概率和为0(不太可能发生)next_token_id = torch.argmax(probabilities, dim=-1, keepdim=True) # 退化为贪心else:# 重新归一化核心集合的概率renormalized_probabilities = masked_probabilities / masked_probabilities.sum(dim=-1, keepdim=True)next_token_id = torch.multinomial(renormalized_probabilities, num_samples=1)if next_token_id.item() == EOS_TOKEN_ID:breakgenerated_tokens.append(next_token_id.item())current_input = torch.cat([current_input, next_token_id], dim=-1)return tokenizer.decode(generated_tokens)
    
    注:Top-P的PyTorch原生实现通常更复杂,涉及到排序和动态切片,上述代码是一个高度简化的概念说明,旨在表达核心思想。实际库如Hugging Face Transformers有优化好的实现。

6. 对比解码 (Contrastive Search / Contrastive Decoding)

  • 原理: 鼓励模型生成与之前上下文不那么相似(避免重复)且同时保持高模型概率的token。通常会有一个“退化惩罚 (degeneration penalty)”项。
  • 特点:
    • 减少重复: 有效缓解LLM的重复生成。
    • 提高连贯性和信息量。
    • 确定性(常见实现)。
  • 适用场景: 长文本生成、避免重复、保持高质量场景。
  • 代码说明 (概念性,核心思想):
    # Contrastive Search 通常不直接修改采样过程,
    # 而是在选择最终token时,基于一个结合了模型置信度 (model confidence) 
    # 和多样性/不重复性 (diversity/dis-similarity) 的目标函数。# 在每一步:
    # 1. 得到当前时间步所有token的概率 P_model(token | context)
    # 2. 对于每个候选token,计算其与前文的相似度/重复度 S(token, context)
    # 3. 目标函数可能类似: Score(token) = alpha * P_model(token | context) - beta * S(token, context)
    #    其中 alpha 和 beta 是超参数。
    # 4. 选择 Score(token) 最高的token。# 这是一个非常高层次的抽象,实际论文有具体公式和实现细节。
    # 例如,SimCTG (A Simple Contrastive Framework for Neural Text Generation) 论文。
    

7. 定向解码 (Constrained Decoding / Guided Decoding)

  • 原理: 在解码中引入外部约束或指导信号,使输出符合特定格式、主题、情感等。
  • 实现方式:
    • 硬约束: 如使用有限状态自动机 (FSA) 或解析器来强制语法正确性。
    • 软约束: 修改logits,例如,通过增加/减少特定token或token类的概率。
    • 插件式: 结合小型判别模型或规则。
  • 特点: 生成更可控、符合需求的文本。
  • 适用场景: 特定格式数据生成、知识图谱问答、可控情感对话。
  • 代码说明 (概念性,软约束-增加某些词的概率):
    generated_tokens = []
    current_input = prompt_tokens
    # 假设我们想引导模型多使用 "科技" 和 "创新" 这两个词
    positive_bias_tokens = [tokenizer.encode("科技")[0], tokenizer.encode("创新")[0]]
    bias_strength = 5.0 for _ in range(max_length):logits = model.forward(current_input)next_token_logits = logits[:, -1, :]# 应用软约束 (增加特定token的logit值)for token_id in positive_bias_tokens:next_token_logits[..., token_id] += bias_strength# 之后可以接贪心、采样等策略next_token_id = torch.argmax(next_token_logits, dim=-1) # 示例用贪心if next_token_id.item() == EOS_TOKEN_ID:breakgenerated_tokens.append(next_token_id.item())current_input = torch.cat([current_input, next_token_id.unsqueeze(0)], dim=-1)return tokenizer.decode(generated_tokens)
    

8. 多轮迭代优化/编辑 (Iterative Refinement / Editing)

  • 原理: 先生成初稿,然后通过一个或多个编辑/优化步骤(可能由另一个LLM或特定模型完成)来改进文本。
  • 特点: 逐步提升文本质量,修复错误。
  • 适用场景: 对文本质量要求极高的任务。
  • 代码说明 (概念性流程):
    def generate_initial_draft(prompt, base_llm, generation_params):# ... 使用上述某种解码策略生成初稿 ...return draft_textdef refine_text(text_to_refine, refinement_llm, refinement_prompt):# ... refinement_llm 根据 refinement_prompt (例如 "请修正以下文本的语法错误并使其更流畅") 来编辑 text_to_refine ...return refined_textinitial_prompt = "写一篇关于人工智能未来的短文。"
    draft = generate_initial_draft(initial_prompt, main_llm, {"strategy": "top_p", "p": 0.9})# 可以进行多轮优化
    refined_draft_1 = refine_text(draft, editor_llm, "请检查并修正语法和流畅性。")
    final_output = refine_text(refined_draft_1, specialist_llm, "请从更专业的角度审阅并增强技术深度。")print(final_output)
    
http://www.xdnf.cn/news/556003.html

相关文章:

  • canvas浅析(一)
  • Java 09Stream流与File类
  • ragas precision计算的坑
  • 使用frp内网穿透本地的虚拟机
  • 几款常用的虚拟串口模拟器
  • springboot+vue3实现在线购物商城系统
  • MS16-075 漏洞 复现过程
  • Ai学习之openai api
  • 武汉火影数字|数字展厅展馆制作:沉浸式体验,全方位互动
  • Vue 3 深度解析:Composition API、Pinia状态管理与路由守卫实战
  • Rocketmq leader选举机制,通过美国大选解释
  • 第32节:基于ImageNet预训练模型的迁移学习与微调
  • 【MySQL】第六弹——表的CRUD进阶(四)聚合查询(下)
  • 图的几种存储方法比较:二维矩阵、邻接表与链式前向星
  • 人工智能驱动的制造业智能决策:从生产排程到质量闭环控制
  • 深度学习-mmcv中build_runner实例化全流程详解
  • EtherCAT通信协议
  • 【Netty】- NIO基础2
  • 易境通海外仓系统PDA蓝牙面单打印:解锁库内作业新姿势
  • 【MySQL成神之路】运算符总结
  • day 31
  • STM32之定时器(TIMER)与脉冲宽度调制(PWM)
  • Glasgow Smile: 2靶场渗透
  • PostGIS栅格数据类型解析【raster】
  • 【深入理解索引扩展—1】提升智能检索系统召回质量的3大利器
  • 详解ip地址、子网掩码、网关、广播地址
  • 系统编程的标准IO
  • 【LINUX操作系统】日志系统——自己实现一个简易的日志系统
  • 容器环境渗透测试工具(docker渗透测试工具、kubernetes)
  • 一文掌握vue3基础,适合自学入门案例丰富