当前位置: 首页 > news >正文

GRPO:利用组内平均奖励来计算优势

GRPO:利用组内平均奖励来计算优势

TL; DR:GRPO 提出将 PPO 中结合价值网络估计 GAE 的优势函数计算方法改为利用组内平均奖励直接计算,降低了 RLHF 的整体训练开销。在 KL 约束的位置和估计方法等细节处也进行了一些改进。

从PPO到GRPO

在训练好 reward model 之后,原始 RLHF 使用 PPO 算法来进行强化学习训练,在标准 PPO 的基础上,RLHF 还在 reward 内加上了一个的 KL 散度作为正则项。用来约束当前模型与参考模型(即上一阶段 SFT 之后的模型)输出分布的差异不要太大。整体上,RLHF PPO 中,回答的第 t t t 个 token 对应的奖励 r t r_t rt 表示为:

r t = r ϕ ( q , o ≤ t ) − β log ⁡ π θ ( o ∣ q , o < t ) π ref ( o ∣ q , o < t ) r_t=r_\phi(q,o_{\le t})-\beta\log\frac{\pi_\theta(o|q,o_{<t})}{\pi_\text{ref}(o|q,o_{<t})} \notag \\ rt=rϕ(q,ot)βlogπref(oq,o<t)πθ(oq,o<t)

PPO 中使用 GAE 来估计优势函数 A t A_t At,我们之前介绍过,GAE 权衡了采用真实采样累积奖励 { r ≥ t } \{r_{\ge t}\} {rt} 的无偏高方差,和采用价值网络估计价值函数 V ψ V_\psi Vψ 的有偏低方差,是目前主流的优势函数形式。

最终,RLHF PPO 的目标函数为:
J PPO ( θ ) = E q ∼ D prompt , o ∼ π θ old 1 ∣ o ∣ ∑ t = 1 ∣ o ∣ min ⁡ [ π θ ( o t ∣ q , o < t ) π θ old ( o t ∣ q , o < t ) A t , clip ( π θ ( o t ∣ q , o < t ) π θ old ( o t ∣ q , o < t ) , 1 − ϵ , 1 + ϵ ) A t ] , 其中 A t = GAE ( r > t , V ψ ) \mathcal{J}_\text{PPO}(\theta)=\mathbb{E}_{q\sim D_\text{prompt},o\sim\pi_{\theta_\text{old}}}\frac{1}{|o|}\sum_{t=1}^{|o|}\min\left[\frac{\pi_\theta(o_t|q,o_{<t})}{\pi_{\theta_\text{old}}(o_t|q,o_{<t})}A_t,\text{clip}\left(\frac{\pi_\theta(o_t|q,o_{<t})}{\pi_{\theta_\text{old}}(o_t|q,o_{<t})},1-\epsilon,1+\epsilon\right)A_t\right],\\其中\quad A_t=\text{GAE}(r_{>t},V_\psi) \notag \\ JPPO(θ)=EqDprompt,oπθoldo1t=1omin[πθold(otq,o<t)πθ(otq,o<t)At,clip(πθold(otq,o<t)πθ(otq,o<t),1ϵ,1+ϵ)At],其中At=GAE(r>t,Vψ)

由于 GAE 需要综合 r > t r_{>t} r>t V ψ V_\psi Vψ 来计算优势函数,因此价值网络 V ψ V_\psi Vψ 的训练是必不可少的。 V ψ V_\psi Vψ 一般是一个与 policy π θ \pi_\theta πθ 参数量相当的价值估计模型,这样整个 PPO 算法需要运行四个模型: π θ \pi_\theta πθ V ψ V_\psi Vψ π ref \pi_\text{ref} πref r ϕ r_\phi rϕ,其中前两者需要训练更新参数,这使得训练的内存和计算开销非常大。

另外,从 RL 算法的角度理解,训练中估计的价值函数 V ψ V_\psi Vψ 是作为 baseline 来计算 GAE 优势函数的,是为了降低方差。但是在 LLM RL 语境下,一般 reward model 只对最后一个 token 打分,给出 reward。GAE 这样的优势函数形式计算太复杂, V ψ V_\psi Vψ 很难在每个 token上都给出准确的估计。

针对上述问题,GRPO 提出不再使用估计的价值函数 V ψ V_\psi Vψ ,而是直接使用对于同一问题的一组多个不同回答的平均奖励,作为 baseline。具体来说,对于每个问题 q q q,使用 π θ old \pi_{\theta_\text{old}} πθold 采样一组 G G G 个回答 { o i } i = 1 G \{o_i\}_{i=1}^G {oi}i=1G,GRPO 的目标函数为:

J GRPO ( θ ) = E q , { o i } i = 1 G 1 G ∑ i = 1 G 1 ∣ o i ∣ ∑ t = 1 ∣ o i ∣ { min ⁡ [ π θ ( o i , t ∣ q , o i , < t ) π θ old ( o i , t ∣ q , o i , < t ) A ^ i , t , clip ( π θ ( o i , t ∣ q , o i , < t ) π θ old ( o i , t ∣ q , o i , < t ) , 1 − ϵ , 1 + ϵ ) A ^ i , t ] − β D KL [ π θ ∣ ∣ π ref ] } \mathcal{J}_\text{GRPO}(\theta)=\mathbb{E}_{q,\ \{o_i\}_{i=1}^G}\frac{1}{G}\sum_{i=1}^G\frac{1}{|o_i|}\sum_{t=1}^{|o_i|}\left\{\min\left[\frac{\pi_\theta(o_{i,t}|q,o_{i,<t})}{\pi_{\theta_\text{old}}(o_{i,t}|q,o_{i,<t})}\hat{A}_{i,t},\text{clip}\left(\frac{\pi_\theta(o_{i,t}|q,o_{i,<t})}{\pi_{\theta_\text{old}}(o_{i,t}|q,o_{i,<t})},1-\epsilon,1+\epsilon\right)\hat{A}_{i,t}\right]-\beta D_\text{KL}[\pi_\theta||\pi_\text{ref}]\right \} \notag \\ JGRPO(θ)=Eq, {oi}i=1GG1i=1Goi1t=1oi{min[πθold(oi,tq,oi,<t)πθ(oi,tq,oi,<t)A^i,t,clip(πθold(oi,tq,oi,<t)πθ(oi,tq,oi,<t)1ϵ,1+ϵ)A^i,t]βDKL[πθ∣∣πref]}
这里的 A ^ i , t \hat{A}_{i,t} A^i,t 是基于组内多个回答的相对奖励计算出的优势函数,其具体的计算形式,我们下面会详细介绍。

这样的好处除了降低训练的计算复杂度之外,还有一点是比较适配于这种 pairwise 形式训练出的 reward model,都是对同一个问题,不同回答的对比打分。

GRPO 还有一点改进是把 RLHF PPO 中的在 reward 中计算 π θ \pi_\theta πθ π ref \pi_\text{ref} πref KL 散度,改为了直接将 KL 约束加到损失函数中,进一步简化了优势函数的计算。并且 KL 散度估计的形式也改成了无偏但是方差更低的 k3 估计:
D KL [ π θ ∣ ∣ π ref ] = π ref ( o i , t ∣ q , o i , < t ) π θ ( o i , t ∣ q , o i , < t ) − log ⁡ π ref ( o i , t ∣ q , o i , < t ) π θ ( o i , t ∣ q , o i , < t ) − 1 D_\text{KL}[\pi_\theta||\pi_\text{ref}]=\frac{\pi_\text{ref}(o_{i,t}|q,o_{i,<t})}{\pi_\theta(o_{i,t}|q,o_{i,<t})}-\log\frac{\pi_\text{ref}(o_{i,t}|q,o_{i,<t})}{\pi_\theta(o_{i,t}|q,o_{i,<t})}-1 \notag \\ DKL[πθ∣∣πref]=πθ(oi,tq,oi,<t)πref(oi,tq,oi,<t)logπθ(oi,tq,oi,<t)πref(oi,tq,oi,<t)1

原文中这个图比较清晰的展现出了 GRPO 相对于 PPO 的改动,主要就是优势函数的计算方式和 KL 约束的位置。

在这里插入图片描述

GRPO 优势计算

对于每个问题 q q q,使用老策略模型 π θ old \pi_{\theta_\text{old}} πθold 生成一组 G G G 个回答 { o 1 , o 2 , … , o G } \{o_1,o_2,\dots,o_G\} {o1,o2,,oG}。使用 outcome supervise(结果监督)和 precesses supervise(过程监督)时,优势函数 A ^ i , t \hat{A}_{i,t} A^i,t 的具体计算分别如下。

对于结果监督,对每个 { q , o } \{q,o\} {q,o} 进行打分,得到一组对应的 rewards r = { r 1 , r 2 , … , r G } \mathbf{r}=\{r_1,r_2,\dots,r_G\} r={r1,r2,,rG}。接下来对这些 rewards 进行归一化:减去组均值,除以组标准差:
A ^ i , t = r ~ i = r i − mean ( r ) std ( r ) \hat{A}_{i,t}=\tilde{r}_i=\frac{r_i-\text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})} \notag \\ A^i,t=r~i=std(r)rimean(r)
对于过程监督,则对一组 G G G 个回答给出的打分为 R = { { r 1 index ( 1 ) , … , r 1 index ( K 1 ) } , … , { r G index ( 1 ) , … , r G index ( K G ) } } \mathbf{R}=\{\{r_1^{\text{index}(1)},\dots,r_1^{\text{index}(K_1)}\},\dots,\{r_G^{\text{index}(1)},\dots,r_G^{\text{index}(K_G)}\}\} R={{r1index(1),,r1index(K1)},,{rGindex(1),,rGindex(KG)}},其中 index ( j ) \text{index}(j) index(j) 表示第 j j j 步最后一个 token 的位置索引, K i K_i Ki 表示第 i i i 个回答的总步数。然后也是减均值除标准差,进行归一化,得到当前步的奖励:
r ~ i index ( j ) = r i index ( j ) − mean ( R ) std ( R ) \tilde{r}_i^{\text{index}(j)}=\frac{r_i^{\text{index}(j)}-\text{mean}(\mathbf{R})}{\text{std}({\mathbf{R})}} \notag \\ r~iindex(j)=std(R)riindex(j)mean(R)
然后计算每个 token 其后的归一化奖励的累加和,即可得到优势:
A ^ i , t = ∑ index ( j ) ≥ t r ~ i index ( j ) \hat{A}_{i,t}=\sum_{\text{index}(j)\ge t}\tilde{r}_i^{\text{index}(j)} \notag \\ A^i,t=index(j)tr~iindex(j)

总结

GRPO 在 DeepSeek Math 中就提出了,在 R1 火爆出圈后,得到了大家的广泛关注和应用。在 R1 爆火后的一段时间里,大家进行 RLHF 训练的主流算法都切换到了 GRPO。具体算法方案上,GRPO 利用组内平均奖励来计算优势,相比于 PPO 需要额外训练一个 value model 的方式,训练开销的降低是肯定的,但是个人认为从效果天花板来看,不一定能比 PPO 更强。

http://www.xdnf.cn/news/248977.html

相关文章:

  • 蓝桥杯Python案例
  • 计算机组成原理实验(5) 堆栈寄存器实验
  • 2025五一杯数学建模ABC题赛题已出
  • ctfshow web入门 web44
  • Python学习笔记(第一部分)
  • 基于深度学习的人脸属性识别算法研究
  • 随机森林实战:从原理到垃圾邮件分类
  • 超稳定性理论
  • 第十四章:生产之路:LLM 应用部署、运维与优化
  • MOOS-ivp使用(一)——水下机器人系统的入门与使用
  • 【2025最新面经】暑期实习常问知识点
  • 前端面经 4
  • 【C++学习笔记】深入理解虚函数和多态
  • 简单句练习--语法基础
  • 50、【OS】【Nuttx】【OSTest】参数解析:函数定义
  • 当算力遇上堵车:AI如何让城市血管不再“血栓”?
  • OpenStack Yoga版安装笔记(25)Nova Cell理解
  • 黑马Java基础笔记-6
  • 伽利略如何测量光速?一场跨越山头的失败实验
  • VBA数据结构深度解析:基础类型、自定义类型与数组操作指南
  • Dagster资产工厂实战:从Python到YAML配置的高效ETL流程
  • 408真题笔记
  • 第十三章:LLM 应用质量保证:评估体系、工具与实战
  • 深入解析三大查找算法:线性查找、二分查找与哈希查找的原理与应用
  • 进程(Process)和操作系统(Operation System)
  • ctfshow web入门 web46
  • 用spring-boot-maven-plugin打包成单个jar有哪些缺点优化方案
  • pandas读取Excel数据(.xlsx和.xls)到treeview
  • JavaScript如何实现类型判断?
  • C语言 指针(2)