huggingface TRL中是怎么获取参考模型的输出的
huggingface TRL中是怎么获取参考模型的输出的
reference_logps
的计算流程 = 让参考模型(ref_model)或禁用了 Adapter 的 LoRA 模型把「prompt + response」完整序列再跑一遍前向 → 取 response 部分每个 token 的对数概率 → 累加得到整句 log p(response|prompt),并缓存到 dataset 的 "reference_logps"
列中。
✅ 具体实现链路(顺着代码读)
1️⃣ 计算入口
compute_reference_log_probs()
负责 单批次 的参考模型 log-prob 计算。这个方法通常在训练循环开始时被调用,用于初始化参考模型的输出基准值。具体执行时机取决于是否设置了precompute_ref_log_probs
参数。
2️⃣ 选择“参考模型”
- 如果 显式传入了
ref_model
(完整权重),就用它。这种情况适用于显式提供独立参考模型的场景,例如使用经过SFT微调的模型作为参考模型。 - 如果 用的是 LoRA(Peft),则通过
null_ref_context()
临时 关闭 Adapter → 相当于拿到"原始基础模型"。这里会保存当前LoRA状态,临时禁用所有Adapter,计算完成后再恢复原状。 - 如果 事前已经离线算好(
precompute_ref_log_probs=True
),就直接跳过,不再重复计算。这在大型数据集上可以显著节省计算时间。
3️⃣ 前向 & 取 logits
with torch.no_grad(): # 禁用梯度计算以节省内存if is_encoder_decoder: # 处理encoder-decoder架构模型logits = ref_model(prompt_input_ids, # encoder输入attention_mask=prompt_attention_mask, # encoder maskdecoder_input_ids=decoder_input_ids, # decoder输入labels=completion_labels # 用于计算loss的标签).logits # 获取输出logitselse: # 处理decoder-only架构模型logits = ref_model(completion_input_ids, # 完整输入序列attention_mask=completion_attention_mask # 完整attention mask).logits
4️⃣ token 级 log-prob → 句级 log-prob
调用静态方法 get_batch_logps()
,内部处理逻辑:
- 模型架构处理:
- 对 decoder-only 模型:执行logits右移1位(使用logits[:-1]预测labels[1:])
- encoder-decoder:直接使用原始labels作为target
- 概率计算:
- 使用
log_softmax
将logits转换为对数概率 - 通过
gather
操作提取对应token的概率
- 使用
- 掩码处理:
- 使用
attention_mask
过滤掉padding部分 - 确保只计算response部分的概率
- 使用
- 聚合计算:
- 对有效token的对数概率进行求和
- 可选地计算平均对数概率(当
average_log_prob=True
时)
5️⃣ 缓存到 dataset
- 首次加载:训练集/验证集第一次被
get_train_dataloader()
/get_eval_dataloader()
调用时:- 遍历整个dataloader的所有批次
- 批量计算
reference_logps
和reference_KL_logps
- 使用
dataset.add_column()
将结果存入数据集
- 后续使用:直接从dataset列读取缓存值,避免重复计算
✅ 关键代码片段
# 1. 计算单条 log-prob
completion_logps = self.get_batch_logps(completion_logits, # 模型输出的logitspadded_batch["completion_labels"], # 目标token IDsaverage_log_prob=False, # 返回总和而非平均值label_pad_token_id=self.label_pad_token_id, # 填充token ID...
)# 2. 存列
self.train_dataset = self.train_dataset.add_column(name="reference_logps", # 列名column=torch.cat(reference_completion_logps).float().numpy() # 数值数据
)
✅ 总结
步骤 | 实现位置 | 说明 | 技术细节 |
---|---|---|---|
选参考模型 | __init__ & null_ref_context | 显式ref/LoRA关adapter | 使用peft.utils.get_peft_model_state_dict 保存/恢复状态 |
前向 | compute_reference_log_probs | 得到prompt+response的logits | 支持encoder-decoder和decoder-only两种架构 |
token→句logp | get_batch_logps | shift、mask、sum | 使用torch.nn.functional.log_softmax 和gather |
缓存 | get_train/eval_dataloader | 一次性算完,dataset加列 | 使用HuggingFace Dataset的add_column 方法 |
这种设计使得整个训练过程只需在初始化阶段跑一次参考模型,后续全部复用缓存值,既节省显存(避免同时加载两个模型)又提升训练效率(避免重复计算)。在典型RLHF训练中,这可以节省30%-50%的训练时间。## huggingface TRL中是怎么获取参考模型的输出的
reference_logps
的计算流程 = 让参考模型(ref_model)或禁用了 Adapter 的 LoRA 模型把「prompt + response」完整序列再跑一遍前向 → 取 response 部分每个 token 的对数概率 → 累加得到整句 log p(response|prompt),并缓存到 dataset 的 "reference_logps"
列中。
✅ 具体实现链路(顺着代码读)
1️⃣ 计算入口
compute_reference_log_probs()
负责 单批次 的参考模型 log-prob 计算。
2️⃣ 选择“参考模型”
- 如果 显式传入了
ref_model
(完整权重),就用它。 - 如果 用的是 LoRA(Peft),则通过
null_ref_context()
临时 关闭 Adapter → 相当于拿到“原始基础模型”。 - 如果 事前已经离线算好(
precompute_ref_log_probs=True
),就直接跳过,不再重复计算。
3️⃣ 前向 & 取 logits
with torch.no_grad():if is_encoder_decoder:logits = ref_model(prompt_input_ids,attention_mask=prompt_attention_mask,decoder_input_ids=decoder_input_ids,labels=completion_labels).logitselse:logits = ref_model(completion_input_ids, attention_mask=completion_attention_mask).logits
4️⃣ token 级 log-prob → 句级 log-prob
调用静态方法 get_batch_logps()
,内部:
- 对 decoder-only 模型:shift 1 位(logits[:-1] vs labels[1:])
- encoder-decoder:labels 本身即为 response,无需额外 shift
- 用
log_softmax + gather
拿到每个 token 的 log p,再 mask 掉 prompt 与 padding,最后 按 response 长度求和。
5️⃣ 缓存到 dataset
- 训练集 / 验证集第一次被
get_train_dataloader()
/get_eval_dataloader()
调用时,会遍历整个 dataloader,把算好的reference_logps
与reference_KL_logps
用dataset.add_column()
存进去。 - 以后再训练/评估时直接读取,不再重复推理参考模型,极大加速。
✅ 关键代码片段
# 1. 计算单条 log-prob
completion_logps = self.get_batch_logps(completion_logits,padded_batch["completion_labels"],average_log_prob=False,...
)# 2. 存列
self.train_dataset = self.train_dataset.add_column(name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy()
)
✅ 总结
步骤 | 实现位置 | 说明 |
---|---|---|
选参考模型 | __init__ & null_ref_context | 显式 ref / LoRA 关 adapter |
前向 | compute_reference_log_probs | 得到 prompt+response 的 logits |
token→句 logp | get_batch_logps | shift、mask、sum |
缓存 | get_train/eval_dataloader | 一次性算完,dataset 加列 |
这样,整个训练过程 只需跑一次参考模型,后续全部复用缓存值,既省显存又省时间。