当前位置: 首页 > news >正文

huggingface TRL中是怎么获取参考模型的输出的

huggingface TRL中是怎么获取参考模型的输出的

reference_logps 的计算流程 = 让参考模型(ref_model)或禁用了 Adapter 的 LoRA 模型把「prompt + response」完整序列再跑一遍前向 → 取 response 部分每个 token 的对数概率 → 累加得到整句 log p(response|prompt),并缓存到 dataset 的 "reference_logps" 列中。


✅ 具体实现链路(顺着代码读)

1️⃣ 计算入口

compute_reference_log_probs() 负责 单批次 的参考模型 log-prob 计算。这个方法通常在训练循环开始时被调用,用于初始化参考模型的输出基准值。具体执行时机取决于是否设置了precompute_ref_log_probs参数。

2️⃣ 选择“参考模型”
  • 如果 显式传入了 ref_model(完整权重),就用它。这种情况适用于显式提供独立参考模型的场景,例如使用经过SFT微调的模型作为参考模型。
  • 如果 用的是 LoRA(Peft),则通过 null_ref_context() 临时 关闭 Adapter → 相当于拿到"原始基础模型"。这里会保存当前LoRA状态,临时禁用所有Adapter,计算完成后再恢复原状。
  • 如果 事前已经离线算好precompute_ref_log_probs=True),就直接跳过,不再重复计算。这在大型数据集上可以显著节省计算时间。
3️⃣ 前向 & 取 logits
with torch.no_grad():  # 禁用梯度计算以节省内存if is_encoder_decoder:  # 处理encoder-decoder架构模型logits = ref_model(prompt_input_ids,  # encoder输入attention_mask=prompt_attention_mask,  # encoder maskdecoder_input_ids=decoder_input_ids,  # decoder输入labels=completion_labels  # 用于计算loss的标签).logits  # 获取输出logitselse:  # 处理decoder-only架构模型logits = ref_model(completion_input_ids,  # 完整输入序列attention_mask=completion_attention_mask  # 完整attention mask).logits
4️⃣ token 级 log-prob → 句级 log-prob

调用静态方法 get_batch_logps(),内部处理逻辑:

  1. 模型架构处理:
    • 对 decoder-only 模型:执行logits右移1位(使用logits[:-1]预测labels[1:])
    • encoder-decoder:直接使用原始labels作为target
  2. 概率计算:
    • 使用log_softmax将logits转换为对数概率
    • 通过gather操作提取对应token的概率
  3. 掩码处理:
    • 使用attention_mask过滤掉padding部分
    • 确保只计算response部分的概率
  4. 聚合计算:
    • 对有效token的对数概率进行求和
    • 可选地计算平均对数概率(当average_log_prob=True时)
5️⃣ 缓存到 dataset
  • 首次加载:训练集/验证集第一次被get_train_dataloader()/get_eval_dataloader()调用时:
    • 遍历整个dataloader的所有批次
    • 批量计算reference_logpsreference_KL_logps
    • 使用dataset.add_column()将结果存入数据集
  • 后续使用:直接从dataset列读取缓存值,避免重复计算

✅ 关键代码片段

# 1. 计算单条 log-prob
completion_logps = self.get_batch_logps(completion_logits,  # 模型输出的logitspadded_batch["completion_labels"],  # 目标token IDsaverage_log_prob=False,  # 返回总和而非平均值label_pad_token_id=self.label_pad_token_id,  # 填充token ID...
)# 2. 存列
self.train_dataset = self.train_dataset.add_column(name="reference_logps",  # 列名column=torch.cat(reference_completion_logps).float().numpy()  # 数值数据
)

✅ 总结

步骤实现位置说明技术细节
选参考模型__init__ & null_ref_context显式ref/LoRA关adapter使用peft.utils.get_peft_model_state_dict保存/恢复状态
前向compute_reference_log_probs得到prompt+response的logits支持encoder-decoder和decoder-only两种架构
token→句logpget_batch_logpsshift、mask、sum使用torch.nn.functional.log_softmaxgather
缓存get_train/eval_dataloader一次性算完,dataset加列使用HuggingFace Dataset的add_column方法

这种设计使得整个训练过程只需在初始化阶段跑一次参考模型,后续全部复用缓存值,既节省显存(避免同时加载两个模型)又提升训练效率(避免重复计算)。在典型RLHF训练中,这可以节省30%-50%的训练时间。## huggingface TRL中是怎么获取参考模型的输出的
reference_logps 的计算流程 = 让参考模型(ref_model)或禁用了 Adapter 的 LoRA 模型把「prompt + response」完整序列再跑一遍前向 → 取 response 部分每个 token 的对数概率 → 累加得到整句 log p(response|prompt),并缓存到 dataset 的 "reference_logps" 列中。


✅ 具体实现链路(顺着代码读)

1️⃣ 计算入口

compute_reference_log_probs() 负责 单批次 的参考模型 log-prob 计算。

2️⃣ 选择“参考模型”
  • 如果 显式传入了 ref_model(完整权重),就用它。
  • 如果 用的是 LoRA(Peft),则通过 null_ref_context() 临时 关闭 Adapter → 相当于拿到“原始基础模型”。
  • 如果 事前已经离线算好precompute_ref_log_probs=True),就直接跳过,不再重复计算。
3️⃣ 前向 & 取 logits
with torch.no_grad():if is_encoder_decoder:logits = ref_model(prompt_input_ids,attention_mask=prompt_attention_mask,decoder_input_ids=decoder_input_ids,labels=completion_labels).logitselse:logits = ref_model(completion_input_ids, attention_mask=completion_attention_mask).logits
4️⃣ token 级 log-prob → 句级 log-prob

调用静态方法 get_batch_logps(),内部:

  • 对 decoder-only 模型:shift 1 位(logits[:-1] vs labels[1:])
  • encoder-decoder:labels 本身即为 response,无需额外 shift
  • log_softmax + gather 拿到每个 token 的 log p,再 mask 掉 prompt 与 padding,最后 按 response 长度求和
5️⃣ 缓存到 dataset
  • 训练集 / 验证集第一次被 get_train_dataloader() / get_eval_dataloader() 调用时,会遍历整个 dataloader,把算好的 reference_logpsreference_KL_logpsdataset.add_column() 存进去。
  • 以后再训练/评估时直接读取,不再重复推理参考模型,极大加速。

✅ 关键代码片段

# 1. 计算单条 log-prob
completion_logps = self.get_batch_logps(completion_logits,padded_batch["completion_labels"],average_log_prob=False,...
)# 2. 存列
self.train_dataset = self.train_dataset.add_column(name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
)

✅ 总结

步骤实现位置说明
选参考模型__init__ & null_ref_context显式 ref / LoRA 关 adapter
前向compute_reference_log_probs得到 prompt+response 的 logits
token→句 logpget_batch_logpsshift、mask、sum
缓存get_train/eval_dataloader一次性算完,dataset 加列

这样,整个训练过程 只需跑一次参考模型,后续全部复用缓存值,既省显存又省时间。

http://www.xdnf.cn/news/1311913.html

相关文章:

  • Swift 实战:实现一个简化版的 Twitter(LeetCode 355)
  • 新手向:GitCode疑难问题诊疗
  • Java 10 新特性及具体应用
  • 嵌入式硬件篇---电感串并联
  • 2^{-53} 单位舍入误差、机器精度、舍入的最大相对误差界限
  • 实例分割-动手学计算机视觉13
  • docker安装mongodb及java连接实战
  • Effective C++ 条款45:运用成员函数模板接受所有兼容类型
  • Linux怎么查看服务器开放和启用的端口
  • 【原理】C# 字段、属性对比及其底层实现
  • illustrator插件大全 免费插件介绍 Ai设计插件集合 (3)
  • Python语言一键整理xhs评论 基于github的开源项目 MediaCrawler
  • Linux进程概念(四)环境地址变量
  • 同创物流学习记录2·电车
  • 链式二叉树的基本操作——遍历
  • 实时计算 记录
  • 美国服务器环境下Windows容器工作负载基于指标的自动扩缩
  • 从盲区到全域:黎阳之光视频孪生+AI智能算法驱动智慧机场三维感知革命
  • 4.6 Vue 3 中的模板引用 (Template Refs)
  • CSS复习
  • Jenkins安装部署(Win11)和常见配置镜像加速
  • SysTick寄存器(嘀嗒定时器实现延时)
  • 要导入StandardScaler类进行数据标准化,请使用以下语句:
  • VS Code配置MinGW64编译ALGLIB库
  • 《C语言程序设计》笔记p10
  • 【数据分享】上市公司供应链成本分摊数据(2007-2024)
  • 【数据结构】-2- 泛型
  • leetcodehot100 矩阵置零
  • 基于Spring Boot 4s店车辆管理系统 租车管理系统 停车位管理系统 智慧车辆管理系统
  • 谷歌手机刷机和面具ROOT保姆级别教程