当前位置: 首页 > news >正文

PPO和GRPO算法

        verl 是现在非常火的 rl 框架,而且已经支持了多个 rl 算法(ppo、grpo 等等)。

        过去对 rl 的理解很粗浅(只知道有好多个角色,有的更新权重,有的不更新),也曾硬着头皮看了一些论文和知乎,依然有很多细节不理解,现在准备跟着 verl 的代码梳理一遍两个著名的 rl 算法,毕竟代码不会隐藏任何细节!

        虽然 GRPO 算法是基于 PPO 算法改进来的,但是毕竟更简单,所以我先从 GRPO 的流程开始学习,然后再看 PPO。

GRPO 论文中的展示的总体流程:

论文中这张图主要展示了 GRPO 和 PPO 的区别,隐藏了其他的细节。

图中只能注意到以下几个关键点:

  • 没有 Value Model 和输出 v(value)

  • 同一个 q 得出了一组的 o(从 1 到 G)

  • 计算 A(Advantage) 的算法从 GAE 变成了 Group Computation

  • KL 散度计算不作用于 Reward Model,而是直接作用于 Policy Model

        其他细节看不懂,结合论文也依然比较抽象,因为我完全没有 RL 的知识基础,下文中我们结合代码会再一次尝试理解。

        下面是我根据 verl 代码自己 DIY 的流程图(帮助理解):

01 第一步:Rollout

        第一步是 rollout,rollout 是一个强化学习专用词汇,指的是从一个特定的状态按照某个策略进行一些列动作和状态转移。

        在 LLM 语境下,“某个策略”就是 actor model 的初始状态,“进行一些列动作”指的就是推理,即输入 prompt 输出 response 的过程。

verl/trainer/ppo/ray_trainer.py:

gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)

        其背后的实现一般就是是 vllm 或 sglang 这些常见推理框架的离线推理功能,这部分功能相对独立我们先不展开。

权重同步

        一个值得注意的细节是代码里面的 rollout_sharding_manager 实现,它负责每一个大 step 结束后把刚刚训练好的 actor model 参数更新到 vllm 或 sglang。

        这样下一个大 step 的 rollout 采用的就是最新的模型权重(最新的策略)了。

        这是每一个大 step 里面真正要做的第一件事,在真正执行 rollout 之前。

        verl/workers/fsdp_workers.py:

class ActorRolloutRefWorker(Worker):   # ...    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)     def generate_sequences(self, prompts: DataProto):       # ...        with self.rollout_sharding_manager:            # ...            prompts = self.rollout_sharding_manager.preprocess_data(prompts)           output = self.rollout.generate_sequences(prompts=prompts)            output = self.rollout_sharding_manager.postprocess_data(output)

rollout_sharding_manager 的基类是 BaseShardingManager。

verl/workers/sharding_manager/base.py:

class BaseShardingManager:   def __enter__(self):        pass    def __exit__(self, exc_type, exc_value, traceback):        pass    def preprocess_data(self, data: DataProto) -> DataProto:        return data    def postprocess_data(self, data: DataProto) -> DataProto:        return data

  BaseShardingManager 的派生类在各自的 __enter__ 方法中实现了把 Actor Model 的权重 Sync 到 Rollout 实例的逻辑,以保证被 with self.rollout_sharding_manager 包裹的预处理和推理逻辑都是用的最新 Actor Model 权重。

推理 N 次

        此外,GRPO 算法要求对每一个 prompt 都生成多个 response,后续才能根据组间对比得出相对于平均的优势(Advantage)。

verl/trainer/config/ppo_trainer.yaml:

actor_rollout_ref:  rollout:    # number of responses (i.e. num sample times)   n: 1 # > 1 for grpo

        在 _build_rollout 的时候 actor_rollout_ref.rollout.n 被传给了 vLLMRollout 或其他的 Rollout 实现中,从而推理出 n 组 response。

verl/workers/fsdp_workers.py:

class ActorRolloutRefWorker(Worker):    def _build_rollout(self, trust_remote_code=False):        # ...        elif rollout_name == "vllm":            # ...            if vllm_mode == "customized":                rollout = vLLMRollout(                   actor_module=self.actor_module_fsdp,                                  config=self.config.rollout,                   tokenizer=self.tokenizer,                    
model_hf_config=self.actor_model_config,               )

02 第二步:计算 log prob

        log 是 logit,prob 是 probability,合起来就是对数概率,举一个简单的例子来说明什么是 log prob:

词表仅有 5 个词:    
<pad> (ID 0)    
你好 (ID 1)    
世界 (ID 2)   
! (ID 3)    
吗 (ID 4)
prompt:你好
prompt tokens: [1]
response:世界!
response tokens: [2,3]
模型前向传播得到完整的 logits 张量:
[    [-1.0, 0.5, 2.0, -0.5, -1.5],    // 表示 “你好” 后接 “世界” 概率最高,数值为 2.0    [-2.0, -1.0, 0.1, 3.0, 0.2]      // 表示 “你好世界” 后接 “!” 概率最高,数值为 3.0]
对每个 logit 计算 softmax 得到:
[    [-3.65, -2.15, -0.64, -3.15, -4.08],    [-4.34, -3.32, -2.20, -0.20, -2.10]]
提取实际 response 对应的数值:得到 log_probs:
[-0.64, -0.20]

总结下来:

  • 首先计算 prompt + response(来自 rollout)的完整 logits,即每一个 token 的概率分布

  • 截取 response 部分的 logits

  • 对每一个 logits 计算 log_sofmax(先 softmax,然后取对数),取出最终预测的 token 对应的 log_sofmax

  • 最终输出 old_log_probs, size = [batchsize, seq_len]

        此处你可能会有一个疑惑:在上一步 Rollout 的时候我们不是已经进行过完整 batch 的推理了么?

        为什么现在还要重复进行一次 forward 来计算 log_prob,而不是在 generate 的过程中就把 log_prob 保存下来?

答:因为 generate_sequences 阶段为了高效推理,不会保存每一个 token 的 log_prob,相反只关注整个序列的 log_prob。因此需要重新算一遍。

答:另外,vllm 官方 Q&A 中提到了 vllm 框架并不保证 log_probs 的稳定性。因为 pytorch 的 numerical instability 与 vllm 的并发批处理策略导致每一个 token 的 logits/log_probs 结果会略有不同,假如某一个 token 位采样了不同 token id,那么这个误差在后续还会被继续累加。我们在训练过程需要保证 log_probs 的稳定性,因此需要根据已经确定的 token id(即 response)再次 forward 一遍。

old log prob

verl/workers/fsdp_workers.py:

old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)

        指 Actor Model 对整个 batch 的数据(prompt + response)进行 forward 得到的 log_prob

        此处的 “old” 是相对于后续的 actor update 阶段,因为现在 actor model 还没有更新,所以依然采用的是旧策略 (ps:当前 step 的“旧策略”也是上一个大 step 的“新策略”)

ref log prob

verl/trainer/ppo/ray_trainer.py:

ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)

        指 Ref Model 对整个 batch 的数据(prompt + response)进行 forward 得到的 log_prob。

        通常 Ref Model 就是整个强化学习开始之前 Actor Model 最初的模样,换句话说第一个大 step 开始的时候 Actor Model == Ref Model,且 old_log_prob == ref_log_prob。

        Ref Model 的作用是在后续计算 policy loss 之前,计算 KL 散度并作用于 policy loss,目的是让 actor model 不要和最初的 ref model 相差太远。

03第三步:advantage

        advantage 是对一个策略的好坏最直接的评价,其背后就是 Reward Model,甚至也许不是一个 Model,而是一个粗暴的 function,甚至一个 sandbox 把 prompt+response 执行后得出的结果。

        在 verl 中允许使用上述多种 Reward 方案中的一种或多种,并把得出的 score 做合。

verl/trainer/ppo/ray_trainer.py:

# compute reward model score
if self.use_rm:    reward_tensor = self.rm_wg.compute_rm_score(batch)    batch = batch.union(reward_tensor)
if self.config.reward_model.launch_reward_fn_async:    future_reward = compute_reward_async.remote(batch, self.config, self.tokenizer)
else:   reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)

然后用这个 score 计算最终的 advantage。

verl/trainer/ppo/ray_trainer.py:

# compute advantages, executed on the driver process
norm_adv_by_std_in_grpo = self.config.algorithm.get(    "norm_adv_by_std_in_grpo", True)  
# GRPO adv normalization factorbatch = compute_advantage(    batch,    
adv_estimator=self.config.algorithm.adv_estimator,   gamma=self.config.algorithm.gamma,    
lam=self.config.algorithm.lam,    
num_repeat=self.config.actor_rollout_ref.rollout.n,    norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,)

04第四步:actor update(小循环)

        在 PPOTrainer 中简单地一行调用,背后可是整个 GRPO 算法中最关键的步骤:

actor_output = self.actor_rollout_wg.update_actor(batch)

        在这里,会把上面提到的整个 batch 的数据再根据 actor_rollout_ref.actor.ppo_mini_batch_size 配置的值拆分成很多个 mini batch。

        然后对每一个 mini batch 数据进行一轮 forward + backward + optimize step,也就是小 step。

new log prob

        每一个小 step 中首先会对 mini batch 的数据计算(new)log_prob,第一个小 step 得到的值还是和 old_log_prob 一模一样的。

pg_loss

        然后通过输入所有 Group 的 Advantage 以新旧策略的概率比例(old_log_prob 和 log_prob),得出 pg_loss(Policy Gradient),这是最终用于 backward 的 policy loss 的基础部分。

        再次描述一下 pg_loss 的意义,即衡量当前策略(log_prob)相比于旧策略(old_log_prob),在当前优势函数(advantage)指导下的改进程度。

verl/workers/actor/dp_actor.py:

pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(    old_log_prob=old_log_prob,    
log_prob=log_prob,    
advantages=advantages,    
response_mask=response_mask,    
cliprange=clip_ratio,    
cliprange_low=clip_ratio_low,    
cliprange_high=clip_ratio_high,    
clip_ratio_c=clip_ratio_c,    
loss_agg_mode=loss_agg_mode,)

entropy loss

        entropy 指策略分布的熵 (Entropy):策略对选择下一个动作(在这里是下一个 token)的不确定性程度。

        熵越高,表示策略输出的概率分布越均匀,选择各个动作的概率越接近,策略的探索性越强;熵越低,表示策略越倾向于选择少数几个高概率的动作,确定性越强。

  entropy_loss 指 entropy 的 平均值,是一个标量,表示探索性高低。

verl/workers/actor/dp_actor.py:

if entropy_coeff != 0:   entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)   # compute policy loss    policy_loss = pg_loss - entropy_loss * entropy_coeff
else:   policy_loss = pg_loss

计算 KL 散度

        这里用到了前面 Ref Model 推出的 ref_log_prob,用这个来计算 KL 并作用于最后的 policy_loss,保证模型距离 Ref Model(初始的模型)偏差不会太大。

verl/workers/actor/dp_actor.py:

if self.config.use_kl_loss:    ref_log_prob = data["ref_log_prob"]   # compute kl loss    kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type    )    kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode    )    policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef    metrics["actor/kl_loss"] = kl_loss.detach().item()    metrics["actor/kl_coef"] = self.config.kl_loss_coef

反向计算

verl/workers/actor/dp_actor.py:

loss.backward()

        持续循环小 step,直到遍历完所有的 mini batch,Actor Model 就完成了本轮的训练,会在下一个大 step 前把权重 sync 到 Rollout实例当中,准备处理下一个大 batch 数据。

http://www.xdnf.cn/news/789805.html

相关文章:

  • 大模型的外围关键技术
  • 【面试】音视频面试
  • 亮数据网页解锁器:让数据触手探索亮数据解锁工具:打破网页数据采集的局限
  • GPIO的内部结构与功能解析
  • Spring Boot Actuator未授权访问漏洞修复
  • RS232/RS485 光电隔离转换器DAM-3210A
  • 学习STC51单片机26(芯片为STC89C52RCRC)
  • Python训练营打卡Day42
  • Java-IO流之字节输入流详解
  • Spring AOP 和 AspectJ 有什么区别
  • Unity ARPG战斗系统 _ RootMotion相关知识点
  • 如何构建自适应架构的镜像
  • Diffusion Models: A Comprehensive Survey of Methods and Applications
  • 网络攻防技术七:计算机木马
  • Java高级 | 【实验二】控制器类+相关注解知识
  • InternLM2/LM2.5/ViT/VL1.5/VL2.0笔记: 核心点解析
  • 服装产品属性描述数据集(19197条),AI智能体知识库收集~
  • ULVAC DC-10-4P 400V input 10kW DC Pulse power supply 爱发科直流电源
  • ESOP股权管理平台完整解决方案
  • 基于LLaMA-Factory和Easy Dataset的Qwen3微调实战:从数据准备到LoRA微调推理评估的全流程指南
  • 开源模型应用落地-OpenAI Agents SDK-集成Qwen3-8B(一)
  • CDGP|数据治理:实现数据“可用不可见”“流通不流失”
  • [QMT量化交易小白入门]-六十、bt实现基于不同基准指数的量化策略回测
  • BFS进阶刷题
  • Spring 中如何开启事务?
  • 嵌入式学习笔记 - freeRTOS任务栈在初始化以及任务切换时的压栈出栈过程分析
  • 黑马程序员TypeScript课程笔记1(1-10)
  • 云开发实现新闻列表小程序
  • Cat.1与Cat.4区别及应用场景
  • QLora基础与进阶指南