当前位置: 首页 > java >正文

组相对策略优化(GRPO):原理及源码解析

文章目录

    • PPO vs GRPO
    • PPO的目标函数
    • GRPO的目标函数
      • KL散度约束与估计
      • ORM监督RL的结果
      • PRM监督RL的过程
      • 迭代RL
      • 算法流程
    • GRPO损失的不同版本
    • GRPO源码解析

  • DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models

PPO vs GRPO

在这里插入图片描述

PPO的目标函数

J P P O ( θ ) = E [ q ∼ P ( Q ) , o ∼ π θ old  ( O ∣ q ) ] 1 ∣ o ∣ ∑ t = 1 ∣ o ∣ min ⁡ [ π θ ( o t ∣ q , o < t ) π θ old  ( o t ∣ q , o < t ) A t , clip ⁡ ( π θ ( o t ∣ q , o < t ) π θ old  ( o t ∣ q , o < t ) , 1 − ε , 1 + ε ) A t ] \begin{align*} \mathcal{J}_{P P O}(\theta) &=\mathbb{E}\left[q \sim P(Q), o \sim \pi_{\theta_{\text {old }}}(O \mid q)\right]\\ &\frac{1}{|o|}\sum_{t=1}^{|o|} \min \left[\frac{\pi_\theta\left(o_t \mid q, o_{<t}\right)}{\pi_{\theta_{\text {old }}}\left(o_t \mid q, o_{<t}\right)} A_t, \operatorname{clip}\left(\frac{\pi_\theta\left(o_t \mid q, o_{<t}\right)}{\pi_{\theta_{\text {old }}}\left(o_t \mid q, o_{<t}\right)}, 1-\varepsilon, 1+\varepsilon\right) A_t\right] \end{align*} JPPO(θ)=E[qP(Q),oπθold (Oq)]o1t=1omin[πθold (otq,o<t)πθ(otq,o<t)At,clip(πθold (otq,o<t)πθ(otq,o<t),1ε,1+ε)At]

A t A_t At是使用广义优势估计(GAE)基于奖励 { r ≥ t } \{r_{\ge t}\} {rt}和状态价值 V ψ V_{\psi} Vψ计算的优势值,需联合训练策略模型和状态价值模型。通常为避免奖励模型被过度拟合而产生异常输出,标准做法为每一个token的奖励添加策略模型和参考模型的KL惩罚。
r t = r φ ( q , o ≤ t ) − β log ⁡ π θ ( o t ∣ q , o < t ) π r e f ( o t ∣ q , o < t ) r_t=r_{\varphi}\left(q, o_{\leq t}\right)-\beta \log \frac{\pi_\theta\left(o_t \mid q, o_{<t}\right)}{\pi_{r e f}\left(o_t \mid q, o_{<t}\right)} rt=rφ(q,ot)βlogπref(otq,o<t)πθ(otq,o<t)

GRPO的目标函数

PPO算法使用价值模型输出作为优势的baseline,指导策略模型更新。价值模型一般与策略模型同尺寸,训练时占显存、耗算力。在LLM生成场景下,奖励函数给出整个response的分数,再加到最后一个token的奖励上,价值模型要预测token-level的奖励,比较困难。

GRPO通过对单个query采样多个response,取平均奖励作为baseline不需要使用价值模型(foregoes critic model),目标函数为:

J G R P O ( θ ) = E [ q ∼ P ( Q ) , { o i } i = 1 G ∼ π θ o l d ( O ∣ q ) ] 1 G ∑ i = 1 G 1 ∣ o i ∣ ∑ t = 1 ∣ o i ∣ { min ⁡ [ π θ ( o i , t ∣ q , o i , < t ) π θ o l d ( o i , t ∣ q , o i , < t ) A ^ i , t , clip ⁡ ( π θ ( o i , t ∣ q , o i , < t ) π θ o l d ( o i , t ∣ q , o i , < t ) , 1 − ε , 1 + ε ) A ^ i , t ] − β D K L [ π θ ∣ ∣ π r e f ] } \begin{align*} \mathcal{J}_{G R P O}(\theta) & =\mathbb{E}\left[q \sim P(Q),\left\{o_i\right\}_{i=1}^G \sim \pi_{\theta_{o l d}}(O \mid q)\right] \\ & \frac{1}{G} \sum_{i=1}^G \frac{1}{\left|o_i\right|} \sum_{t=1}^{\left|o_i\right|}\left\{\min \left[\frac{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_{\theta_{o l d}}\left(o_{i, t} \mid q, o_{i,<t}\right)} \hat{A}_{i, t}, \operatorname{clip}\left(\frac{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_{\theta_{o l d}}\left(o_{i, t} \mid q, o_{i,<t}\right)}, 1-\varepsilon, 1+\varepsilon\right) \hat{A}_{i, t}\right]-\beta \mathbb{D}_{K L}\left[\pi_\theta| | \pi_{r e f}\right]\right\} \end{align*} JGRPO(θ)=E[qP(Q),{oi}i=1Gπθold(Oq)]G1i=1Goi1t=1oi{min[πθold(oi,tq,oi,<t)πθ(oi,tq,oi,<t)A^i,t,clip(πθold(oi,tq,oi,<t)πθ(oi,tq,oi,<t),1ε,1+ε)A^i,t]βDKL[πθ∣∣πref]}

建立组内竞争机制,不需要外部独立的Critic。比组内平均分高的响应获得正分数,低的获得负分数,鼓励模型生成比平均水平更好的响应,使得平均得分越来越高。

KL散度约束与估计

KL散度项用于约束策略更新幅度,我们使用k3型的KL散度估计:
D K L [ π θ ∣ ∣ π r e f ] = π r e f ( o i , t ∣ q , o i , < t ) π θ ( o i , t ∣ q , o i , < t ) − log ⁡ π r e f ( o i , t ∣ q , o i , < t ) π θ ( o i , t ∣ q , o i , < t ) − 1 \mathbb{D}_{K L}\left[\pi_\theta| | \pi_{r e f}\right]=\frac{\pi_{r e f}\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}-\log \frac{\pi_{r e f}\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}-1 DKL[πθ∣∣πref]=πθ(oi,tq,oi,<t)πref(oi,tq,oi,<t)logπθ(oi,tq,oi,<t)πref(oi,tq,oi,<t)1

解释: 奖励模型经比较/偏好数据集训练,使用相对优势的RL方法与奖励模型也比较匹配。PPO方法将策略模型和参考模型的KL散度作为奖励的惩罚,GRPO不惩罚奖励,而是将KL惩罚直接放在策略损失里面,避免在 A i , t A_{i,t} Ai,t中引入复杂的计算。

通常 x x x无法穷举,一般通过多次采样求平均方式估计期望,即无偏估计,KL散度的定义及无偏估计为
K L [ p ∣ ∣ q ] = ∑ x p ( x ) log ⁡ ( p ( x ) q ( x ) ) = E x ∼ p [ p ( x ) q ( x ) ] ≈ 1 N log ⁡ ( p ( x ) q ( x ) ) KL[p||q]=\sum_x p(x)\log\left(\dfrac{p(x)}{q(x)}\right)=\mathbb E_{x\sim p}\left[\frac{p(x)}{q(x)}\right]\approx\frac{1}{N}\log\left(\frac{p(x)}{q(x)}\right) KL[p∣∣q]=xp(x)log(q(x)p(x))=Exp[q(x)p(x)]N1log(q(x)p(x))

采样与期望: 如果p中有n个不同的x,从中随机采样m个x,m>>n,则重复x的个数除以m就近似为概率p(x)。

r = q ( x ) / p ( x ) r=q(x)/p(x) r=q(x)/p(x),几种KL散度采样估计:

  • k1 − log ⁡ r -\log r logr无偏、高方差,半数样本为负(KL为正),偏差比较高。
  • k2 1 2 ( log ⁡ r ) 2 \dfrac{1}{2}(\log r)^2 21(logr)2有偏、低方差,始终为正,明确反映出分布之间的偏离程度。
  • k3 − log ⁡ r + ( r − 1 ) -\log r + (r - 1) logr+(r1)无偏、低方差,始终为正。启发式设计,k1加上期望为0,并且与其负相关的项。
    • p ( x ) p(x) p(x) q ( x ) q(x) q(x)分步接近时, r r r的期望为1,新增项 r − 1 r-1 r1为0;
    • r r r增大,k1 − log ⁡ ( r ) -\log(r) log(r)减小,新增项 ( r − 1 ) (r-1) (r1)增加;
    • 直观表达, l o g ( p / q ) + ( q / p − 1 ) log(p/q)+(q/p-1) log(p/q)+(q/p1) p ( x ) p(x) p(x)大于 q ( x ) q(x) q(x)时,k1大于0,新增修正项小于1;

ORM监督RL的结果

对于每个query q q q,从 π θ o l d \pi_{\theta_{old}} πθold中采样一组输出 G = { o 1 , o 2 , ⋯ , o G } G=\{o_1,o_2,\cdots,o_{G}\} G={o1,o2,,oG},奖励模型对这些输出(或者说结果Outcome)打分 r = { r 1 , r 2 , ⋯ , r G } {\bf r}=\{r_1,r_2,\cdots,r_{G}\} r={r1,r2,,rG},将这些奖励标准化可作为每个输出 o i o_i oi在结束位置的组内相对优势
A ^ i , t = r ~ i = r i − mean ⁡ ( r ) std ⁡ ( r ) \hat{A}_{i, t}=\widetilde{r}_i=\frac{r_i-\operatorname{mean}(\mathbf{r})}{\operatorname{std}(\mathbf{r})} A^i,t=r i=std(r)rimean(r)

PRM监督RL的过程

结果监督仅提供了每个输出在结束位置的奖励,不足以监督复杂的数学推理任务。

为进行过程监督,对每个推理步骤打分:
R = { { r 1 i n d e x ( 1 ) , ⋯ , r 1 i n d e x ( K 1 ) } , ⋯ , { r G i n d e x ( 1 ) , ⋯ , r G i n d e x ( K G ) } } \mathbf{R}=\left\{\left\{r_1^{{index}(1)}, \cdots, r_1^{{index}\left(K_1\right)}\right\}, \cdots,\left\{r_G^{{index}(1)}, \cdots, r_G^{{index}\left(K_G\right)}\right\}\right\} R={{r1index(1),,r1index(K1)},,{rGindex(1),,rGindex(KG)}}

其中 i n d e x ( j ) index(j) index(j)表示第 j j j步的结束token,标准化的步骤奖励为
r ~ i i n d e x ( j ) = r i i n d e x ( j ) − mean ⁡ ( R ) std ⁡ ( R ) \tilde{r}_i^{{index}(j)}=\frac{r_i^{{index}(j)}-\operatorname{mean}(\mathbf{R})}{\operatorname{std}(\mathbf{R})} r~iindex(j)=std(R)riindex(j)mean(R)

每一个token的优势等于之后所有步骤的标准化奖励和:
A ^ i , t = ∑ i n d e x ( j ) ≥ t r ~ i i n d e x ( j ) \hat A_{i,t}=\sum_{index(j)\ge t}\tilde r_i^{index(j)} A^i,t=index(j)tr~iindex(j)

迭代RL

随着策略模型更新,奖励模型可能不足以监督策略模型。GRPO使用迭代的方式,从新的策略模型中采样数据,加上10%的历史数据,以继续训练方式更新奖励模型。之后,将最新的策略模型设置为参考模型,继续训练策略模型,重复上述过程。

算法流程

在这里插入图片描述

奖励模型使用base模型初始化
奖励模型训练数据

GRPO损失的不同版本

GRPO目标可以定义为
L G R P O ( θ ) = − 1 G ∑ i = 1 G 1 ∣ o i ∣ ∑ t = 1 ∣ o i ∣ l i , t , w . t . l i , t = π θ ( o i , t ∣ q , o i , < t ) [ π θ ( o i , t ∣ q , o i , < t ) ] n o g r a d A ^ i , t − β D K L [ π θ ∥ π r e f ] \mathcal{L}_{\mathrm{GRPO}}(\theta)=-\frac{1}{G} \sum_{i=1}^G \frac{1}{\left|o_i\right|} \sum_{t=1}^{\left|o_i\right|} l_{i, t}, \quad w.t.\ l_{i, t}=\frac{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}{\left[\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)\right]_{\mathrm{no} \mathrm{grad}}} \hat{A}_{i, t}-\beta \mathbb{D}_{\mathrm{KL}}\left[\pi_\theta \| \pi_{\mathrm{ref}}\right] LGRPO(θ)=G1i=1Goi1t=1oili,t,w.t. li,t=[πθ(oi,tq,oi,<t)]nogradπθ(oi,tq,oi,<t)A^i,tβDKL[πθπref]

DAPO指出,GRPO使用sample-level损失,在long-COT场景下,long-response惩罚不足,导致其输出质量比较低。DAPO使用token-level损失,所有response中的每个token的奖励更加平衡,不受response长度的影响。
L D A P O ( θ ) = − 1 ∑ i = 1 G ∣ o i ∣ ∑ i = 1 G ∑ t = 1 ∣ o i ∣ l i , t \mathcal{L}_{\mathrm{DAPO}}(\theta)=-\frac{1}{\sum_{i=1}^G\left|o_i\right|} \sum_{i=1}^G \sum_{t=1}^{\left|o_i\right|} l_{i, t} LDAPO(θ)=i=1Goi1i=1Gt=1oili,t

Dr. GRPO指出,DAPO没有完全消除不同response长度偏差的影响,为了更彻底的消除,其使用常数替代序列长度:
L Dr. GRPO ( θ ) = − 1 L G ∑ i = 1 G ∑ t = 1 ∣ o i ∣ l i , t \mathcal{L}_{\text{Dr. GRPO}}(\theta) = -\frac{1}{LG} \sum_{i=1}^{G} \sum_{t=1}^{|o_i|} l_{i, t} LDr. GRPO(θ)=LG1i=1Gt=1oili,t

GRPO源码解析

代码库trl中GRPOTrainer的实现,继承于Transformers Trainer,重载_prepare_inputscompute_loss方法

源码在这里:https://github.com/huggingface/trl/blob/v0.18.1/trl/trainer/grpo_trainer.py

算法过程

  1. 构造批次输入prompts
    • 使用自定义的RepeatSampler采样批次,保证每个prompt能重复采样多次,并且能跨进程同步分组;
    • 风格为generatechat_completions,执行左padding,左truncate;
  2. 采样completions_prepare_inputs中调用_generate_and_score_completions,参数为temperature=0.9top_p=1.0max_new_tokens=256
    • 若使用vllm server:
      • 权重同步:确保policy model和vllm model的参数同步;
      • 数据并行采样:主进程上gather所有进程上的prompts,为每个不重复的prompt生成num_generations个completions;
      • 广播分配:主进程上broadcast所有completions到其它进程,所有进程截取自己prompts的completions;
    • 若使用transformers标准的model.generate:
      • 独立生成每个prompt的completion,包含重复的prompt(同一prompt多次prefill),计算低效;
  3. 处理completion padding
    • 根据completion中EOS的位置计算completion长度,并mask首个EOS后的token,只保留有效的completion token;
    • mask所有没有EOS的completion,避免异常completion对loss影响过大(可选);
  4. 计算old_logprobs:若使用相同completion多次迭代优化,计算当前policy model的logprobs作为old_logprobs,用于后续epoch中计算概率比率;
  5. 计算scores:每个reward model/reward func计算每条prompt+completion的score并加权,得到每条sentence的score;
  6. 计算advantages:gather所有进程上的scores,分组标准化,即奖励 - 奖励均值 / 奖励标准差(可选)
  7. 计算loss
    • 计算policy model的logprobs;
    • 计算reference model的ref_logprobs;
    • 计算policy model和reference model之间在每个completion token的kl散度,使用k3无偏估计:kl=log(p/q)+(q/p-1),如果p和q都是对数概率,则kl=p-q+exp(q-p)-1,即kl损失
    • 使用logprobs和old_logprobs计算概率比率并裁剪,限制参数更新幅度(重要性采样,PPO算法的核心),利用裁剪后概率比率clamped_ratio、advantage和completion mask,计算每个token的策略损失
    • 损失加权求和:加权求和token-level的策略损失和kl损失,kl损失权重小,非主导;
    • 损失均值化:loss有多种求和/平均方式,bnpo loss不考虑每条样本的completion长度的影响,取所有token的平均loss。grpo_loss对每条completion依次在token-level、sample-level上求和平均,对长completion的惩罚不足;
    • 使用梯度下降更新policy model;
http://www.xdnf.cn/news/10671.html

相关文章:

  • Nginx + Tomcat负载均衡群集
  • VBA 64位API声明语句第010讲
  • Nginx+Tomcat负载均衡集群
  • 数据挖掘顶刊《IEEE Transactions on Knowledge and Data Engineering》2025年5月研究热点都有些什么?
  • 2025年06月03日Github流行趋势
  • 金融中的线性优化:投资组合分配与求解器 - Part 2
  • TDengine 高级功能——流计算
  • 开源量子模拟引擎:Quantum ESPRESSO本地部署教程,第一性原理计算轻松入门!
  • PostgreSQL数据库备份
  • 【Oracle】视图
  • 3. 简述node.js特性与底层原理
  • 基于Halcon深度学习之分类
  • SpringBoot系列之RabbitMQ 实现订单超时未支付自动关闭功能
  • AI+3D 视觉重塑塑料袋拆垛新范式:迁移科技解锁工业自动化新高度
  • Neo4j 数据导入:原理、技术、技巧与最佳实践
  • 深入理解Android进程间通信机制
  • uniapp 开发企业微信小程序,如何区别生产环境和测试环境?来处理不同的服务请求
  • SOC-ESP32S3部分:28-BLE低功耗蓝牙
  • Rust 学习笔记:使用自定义命令扩展 Cargo
  • 8.RV1126-OPENCV 视频中添加LOGO
  • 鸿蒙生态再添翼:身份证银行卡识别引领智能识别技术新篇章
  • Python数据可视化科技图表绘制系列教程(一)
  • 20250603在荣品的PRO-RK3566开发板的Android13下的命令行查看RK3566的温度
  • MS1023/MS1224——10MHz 到 80MHz、10:1 LVDS 并串转换器(串化器)/串并转换器(解串器)
  • 深度解析 Qt 最顶层类 QObject:继承关系与内存生命周期管理
  • ERP、OA、CRM三个企业管理软件的区别与联系
  • # [特殊字符] Unity UI 性能优化终极指南 — LayoutGroup篇
  • 微软推出 Bing Video Creator,免费助力用户轻松创作 AI 视频
  • 03.搭建K8S集群
  • 【计算机网络 第8版】谢希仁编著 第六章应用层 题型总结1 编码