当前位置: 首页 > news >正文

ReMax:将贪婪采样的 reward 作为 baseline

ReMax:将贪婪采样的 reward 作为 baseline

TL; DR:在经典的 REINFORCE 算法的基础上,取每个问题贪婪采样结果的 reward 作为 baseline,以降低方差。

导语

众所周知,在 RLHF 中 OpenAI 使用自己的 PPO 算法来进行强化学习训练。本文作者认为,PPO 诚然是解决常见强化学习问题的好方法,但是,在 LLM 领域,却有独特三点特性,使得 PPO 可能并不是最适合 RLHF 的强化学习算法:

  1. Fast simulation。在传统的 RL 任务(游戏、机器人等)中,simulation 是一件很慢且成本很高的事情。但是在 LLM 中,simulation(rollout) 只需要将 prompt 输入给模型,然后将模型输出的答案输入给 reward model (or function),就可以得到 trajectory 和 reward 了,速度其实是比较快的,一般只需要几秒就可以生成一条样本;
  2. Deterministic environment。将 LLM RLHF 套到强化学习的语境下,环境的状态转移( s t → s t + 1 s_t\rightarrow s_{t+1} stst+1)其实就是将刚生成的 token 拼到 context 后面,是确定性的,而奖励来自 reward model (or function),也是确定性的。因此说,RLHF 的整个强化学习环境都是确定性的(当然 Agent,也就是 LLM 本身的采样生成是可以有随机性的);
  3. Trajectory-level reward。在 RLHF 中,奖励一般是 trajectory-level 的,即在完整回答生成后,才给出一个 reward,而中间过程,都是没有 reward 的。也正是因为这一点,作者认为 PPO 中的 value model 以及对应的 TD 残差学习,可能是不太合适的。

基于 LLM 的这三点性质,作者认为 RLHF 训练中的其实不需要 value model 即可简单高效地计算期望回报,并提出了 ReMax 方法。

在这里插入图片描述

从 REINFORCE 到 ReMax

作者对 RLHF 的改进没有从 PPO 出发,而是考虑了一种更早期的策略梯度类方法:REINFORCE。REINFORCE 是一种古老但经典的策略梯度算法,其形式我们在之前已经推导过,模型更新的策略梯度为:
∇ θ J ( θ ) = E τ ∼ π θ [ ∑ t = 0 T ∇ θ log ⁡ π θ ( a t ∣ s t ) R ( τ ) ] \nabla_\theta J(\theta)=\mathbb{E}_{\tau\sim\pi_\theta}\left[\sum_{t=0}^T\nabla_\theta\log\pi_\theta(a_t|s_t) R(\tau)\right] \notag \\ θJ(θ)=Eτπθ[t=0Tθlogπθ(atst)R(τ)]
其中 τ \tau τ 是一条轨迹(trajectory)样本,由 policy π θ \pi_\theta πθ 采样而来。

在 LLM RLHF 语境下,REINFORCE 算法的梯度可以写为:
E q ∼ D , o ∼ π θ [ ∑ t = 0 T ∇ θ log ⁡ π θ ( a t ∣ q , o < t ) R ( q , o 1 : T ) ] \mathbb{E}_{q\sim D,o\sim\pi_\theta}\left[\sum_{t=0}^T\nabla_\theta\log\pi_\theta(a_t|q,o_{<t})R(q,o_{1:T})\right] \notag \\ EqD,oπθ[t=0Tθlogπθ(atq,o<t)R(q,o1:T)]
其中 q q q 是输入给模型的问题,采样自数据集 D D D o t o_t ot 表示模型输出的第 t t t 个 token, R ( q , o 1 : T ) R(q,o_{1:T}) R(q,o1:T) 表示 reward model 对模型的完整回答给出的 sample level 的打分。

REINFORCE 算法直接根据一条轨迹(模型回答) 的累积回报来计算梯度,不用训练和推理额外的 value model,训练开销更小。作者认为,简单的 REINORCE 算法比较契合上面提到的 LLM 的三个特性,反而可能更加适合 RLHF。然而,REINFORCE 算法实际上也有自己的问题,那就是方差太大。

REINFORCE 算法的方差大,是早就众所周知的问题。在传统强化学习中,REINFORCE 算法中 R ( τ ) R(\tau) R(τ) 的方差来自两方面,一是 environment 的随机性,二是 policy 本身的随机性,REINFORCE 中,我们需要采样 T T T 步,直到一次 rollout 结束,才能得到最终的奖励 ,在这个过程中,来自上述两种随机性带来的采样方差不断积累,导致最终的累积奖励 R ( τ ) R(\tau) R(τ) 的方差非常大、

传统强化学习中,解决 REINFORCE 算法方差大的问题,一个比较经典的方案是在累积回报的基础上减掉一个 baseline。
E q ∼ D , o ∼ π θ [ ∑ t = 0 T ∇ θ log ⁡ π θ ( a t ∣ q , o < t ) [ R ( q , o 1 : T ) − b θ ( q ) ] ] \mathbb{E}_{q\sim D,o\sim\pi_\theta}\left[\sum_{t=0}^T\nabla_\theta\log\pi_\theta(a_t|q,o_{<t})[R(q,o_{1:T})\textcolor{red}{-b_\theta(q)}]\right] \notag \\ EqD,oπθ[t=0Tθlogπθ(atq,o<t)[R(q,o1:T)bθ(q)]]
ReMax 也是引入了一个 baseline 来减小 RLHF 训练中的方差。具体来说,ReMax 选用的 baseline 是 policy LLM 对于输入问题 q q q,以 greedy 的策略进行采样(实操中,取 do_sample=False 即可)得到贪婪采样结果的 reward 打分,即取:
b θ ( q ) = R ( q , o ˉ 1 : T ) , o ˉ t ∈ arg ⁡ max ⁡ π θ ( ⋅ ∣ q , o < t ) b_\theta(q)=R(q,\bar{o}_{1:T}),\quad \bar{o}_{t}\in\arg\max\pi_\theta(\cdot|q,o_{<t}) \notag \\ bθ(q)=R(q,oˉ1:T),oˉtargmaxπθ(q,o<t)
REINFORCE 的基础上使用 argmax 计算 baseline,故称为 ReMax。

理论上,作者证明了 ReMax 仍是一个无偏估计,并且方差会更小。直观上来理解,ReMax 是取贪婪解码的结果,也是没有随机性的,因此方差确实会更小。原文的实验结果也展示出 ReMax 比 PPO 训练开销更小,效果也要好一些。

总结

ReMax 是很早就提出 RLHF 不用 value model 的工作之一。论文首先分析了 RLHF 相较于传统 RL 的独特性,然后在经典的 REINFORCE 算法的基础上,引入贪婪采样结果的 reward 作为 baseline,来降低方差。方法简单有效,理论分析也比较充分,是一篇很不错的工作。有一点 concern 在于,(可能是受限于资源)ReMax 的实验是做在 1.3B/7B 的 “小” 模型上的,而在上百 B 的大模型上,Simulation 还有没有那么 Fast,多一次 Simulation(用于算 baseline) 与 value model 的开销相比,是否还有那么大的优势?

http://www.xdnf.cn/news/309817.html

相关文章:

  • Java并发编程-锁(一)
  • miniqtm 模拟账号和实盘账号登陆对数据获取有什么影响
  • vLLM 推理 Qwen2.5-VL-7B 图像
  • 机器人系统设置
  • 小型纯电动汽车轮毂电机及大角度转向系统的数字化设计
  • 卷积神经网络基础(五)
  • 大语言模型(LLM)领域,有几项显著的进展和技术突破
  • JavaSE核心知识点01基础语法01-04(数组)
  • RPM打包格式spec文件设计原理与关键特性说明
  • Python cv2滤波与模糊处理:从原理到实战
  • Matlab/Simulink的一些功能用法笔记(4)
  • AI教你学VUE——Deepseek版
  • 从入门到登峰-嵌入式Tracker定位算法全景之旅 Part 8 |产品化与运维:批量标定、误差监控、OTA 升级与安全防护
  • CSS Border 三角形阴影与多重边框的制作
  • Beetle 树莓派RP2350 - 桌面时钟摆件
  • 内存种类详解
  • tinyrenderer笔记(Shadow Mapping)
  • 方案精读:2024版基于华为IPD与质量管理体系融合的研发质量管理【附全文阅读】
  • AOAAO:算术优化算法与Aquila Optimizer的混合算法
  • langchain4j整合springboot
  • OpenCV的floodFill(漫水填充)分割
  • 静态NAT
  • C++23 新利器:深入解析栈踪迹库 (P0881R7)
  • HTTP协议网络读卡器通讯报文
  • 无法解析导入“pybulletgym”
  • C# System.Text.Json实现高效JSON序列化与反序列化
  • 基于Java多线程实现简单图片下载
  • SLAM算法工程师面经大全:2025年面试真题解析与实战指南
  • 美信监控易:全栈式自主可控的底层架构优势
  • 使用 Poco C++ 库构建轻量级 HTTP 服务器