当前位置: 首页 > ops >正文

Numerical Difference between vLLM logprobs and huggingface logprobs

以下代码引自 https://fengyao.notion.site/off-policy-rl#246721e3f6c480259e6ff598ac4c317b :

# VLLM Side
import torch
from vllm import LLM, SamplingParams
import math
from statistics import mean, median, stdev

# Checkpoint loaded by both vLLM and Hugging Face so the logprobs are comparable.
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
# Sampling temperature used by vLLM and applied to the HF logits before softmax.
TEMPERATURE = 0.7
PROMPTS = [
    "One of the most important things in life is to",
    "The answer to 1 + 1 is",
]
# CUDA device made current before the Hugging Face model is loaded.
HF_DEVICE_INDEX = 1
# Keeps the relative-error denominator away from zero.
EPS = 1e-10


def relative_error(estimate: float, reference: float, eps: float = EPS) -> float:
    """Return ``|estimate - reference| / (|reference| + eps)``.

    The denominator uses the magnitude of ``reference``: log-probabilities are
    non-positive, so adding ``eps`` to the signed value would shrink the
    denominator and make it exactly zero for ``reference == -eps``.
    """
    return abs(estimate - reference) / (abs(reference) + eps)


def summarize(label: str, values: list[float]) -> str:
    """Format max/mean/stdev/median/min of ``values`` on one line.

    An empty list yields ``"<label>: no samples"``; a single value reports
    ``stdev=nan`` because the sample standard deviation needs two points.
    """
    if not values:
        return f"{label}: no samples"
    spread = stdev(values) if len(values) > 1 else float("nan")
    return (
        f"{label}: max={max(values)}, mean={mean(values)}, stdev={spread}, "
        f"median={median(values)}, min={min(values)}"
    )


def generate_with_vllm(prompts, dtype, temperature: float = TEMPERATURE):
    """Sample completions with vLLM and keep token ids plus reported logprobs.

    Returns one dict per prompt with keys ``"input_ids"``, ``"output_ids"``
    and ``"logprobs"`` (the per-step logprob mappings returned by vLLM).
    """
    llm = LLM(model=MODEL_NAME, dtype=dtype, enforce_eager=True)
    outputs = llm.generate(
        prompts,
        sampling_params=SamplingParams(
            max_tokens=512,
            temperature=temperature,
            logprobs=2,
        ),
    )

    samples = []
    for output in outputs:
        completion = output.outputs[0]
        # Sanity check: one logprob mapping per generated token.
        assert len(completion.token_ids) == len(completion.logprobs)
        samples.append(
            {
                "input_ids": output.prompt_token_ids,
                "output_ids": completion.token_ids,
                "logprobs": completion.logprobs,
            }
        )
    return samples


def compare_with_hf(samples, model, temperature: float = TEMPERATURE):
    """Re-score every sample with ``model`` and measure the gap to vLLM.

    Returns ``(relative_logprob_errors, absolute_prob_errors)``, two flat lists
    with one entry per (generated token, reported candidate) pair.
    """
    import torch.nn.functional as F

    rel_errs = []
    prob_errs = []
    for sample in samples:
        prompt_len = len(sample["input_ids"])
        token_ids = torch.tensor(
            [*sample["input_ids"], *sample["output_ids"]], device="cuda"
        ).unsqueeze(0)
        print(token_ids.shape)
        with torch.inference_mode():
            model_outputs = model(token_ids)
            # Element 0 of the model output is used as the logits tensor.
            print(model_outputs[0].shape)
            real_logprobs = F.log_softmax(model_outputs[0] / temperature, dim=-1)
            print(real_logprobs.shape)

            steps = zip(sample["output_ids"], sample["logprobs"])
            for i, (token_id, candidates) in enumerate(steps):
                print("===", token_id, "===")
                # Position p holds the distribution over token p + 1, so
                # generated token i is scored at index prompt_len + i - 1.
                step_logprobs = real_logprobs[0, i - 1 + prompt_len]
                for key, entry in candidates.items():
                    hf_logprob = step_logprobs[key].item()
                    vllm_rel_err = relative_error(entry.logprob, hf_logprob)
                    rel_errs.append(vllm_rel_err)

                    vllm_prob = math.exp(entry.logprob)
                    real_prob = math.exp(hf_logprob)
                    prob_errs.append(abs(vllm_prob - real_prob))

                    # Report only large mismatches on non-dominant tokens.
                    if vllm_rel_err > 0.1 and real_prob < 0.9:
                        print(key, entry, "HF logprobs:", hf_logprob)
                        print(f"Prob: {real_prob}, VLLM: {vllm_prob}")
    return rel_errs, prob_errs


def main():
    """Generate with vLLM, re-score with Hugging Face, print error statistics."""
    dtype = torch.bfloat16

    # vLLM side
    samples = generate_with_vllm(PROMPTS, dtype)

    # HF side: make another GPU current before loading the reference model.
    torch.cuda.set_device(HF_DEVICE_INDEX)
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, torch_dtype=dtype, device_map="cuda"
    )
    vllm_errs, vllm_prob_errs = compare_with_hf(samples, model)

    print("Relative logprob errors")
    print(summarize("VLLM", vllm_errs))
    print("Absolute prob errors")
    print(summarize("VLLM", vllm_prob_errs))


if __name__ == '__main__':
    main()
http://www.xdnf.cn/news/17836.html

相关文章:

  • 数据结构:N叉树 (N-ary Tree)
  • Web 开发 15
  • 4.2 寻址方式 (答案见原书 P341)
  • CIAIE 2025上海汽车内外饰展观察:从美学到功能的产业跃迁
  • Tokenizer(切词器)的不同实现算法
  • 《软件工程导论》实验报告四 详细设计工具
  • 打靶日常-sql注入(手工+sqlmap)
  • 嵌入式学习 day52 IMX6ULL裸机开发-I2C
  • 功能组和功能组状态的概念关系和区别
  • Cursor/VSCode/VS2017 搭建Cocos2d-x环境,并进行正常的调试和运行(简单明了)
  • Docker的相关知识探究详解
  • Linux驱动学习day28(USB驱动,libusb操作)
  • RabbitMQ核心架构与应用
  • DeepSeek-V2:一种强大、经济且高效的混合专家语言模型
  • 区块链技术原理(13)-以太坊燃料费Gas
  • 【数据结构初阶】--排序(三):冒泡排序、快速排序
  • 旋钮键盘项目---foc讲解(开环)
  • 基于WSL搭建Ubuntu 22.04.x LTS开发环境
  • 102、【OS】【Nuttx】【周边】文档构建渲染:安装 Esbonio 服务器
  • Codeforces 无路可走
  • Git代码版本管理
  • 一文打通 AI 知识脉络:大语言模型等关键内容详解
  • Python基础-数据结构
  • 【部署K8S集群】 1、安装前环境准备配置
  • 重塑工业设备制造格局:明远智睿 T113-i 的破局之道
  • 基于多模型的零售销售预测实战指南
  • Spring IOC容器在Web环境中的启动奥秘:深入源码解析
  • 从 LLM 到自主 Agent:OpenCSG 打造开源 AgenticOps 生态
  • 云原生俱乐部-k8s知识点归纳(4)
  • EhViewer安卓ios全版本类下载安装工具的完整路径解析