ReMax:将贪婪采样的 reward 作为 baseline
ReMax:将贪婪采样的 reward 作为 baseline
TL; DR:在经典的 REINFORCE 算法的基础上,取每个问题贪婪采样结果的 reward 作为 baseline,以降低方差。
导语
众所周知,在 RLHF 中 OpenAI 使用自己的 PPO 算法来进行强化学习训练。本文作者认为,PPO 诚然是解决常见强化学习问题的好方法,但是,在 LLM 领域,却有独特三点特性,使得 PPO 可能并不是最适合 RLHF 的强化学习算法:
- Fast simulation。在传统的 RL 任务(游戏、机器人等)中,simulation 是一件很慢且成本很高的事情。但是在 LLM 中,simulation(rollout) 只需要将 prompt 输入给模型,然后将模型输出的答案输入给 reward model (or function),就可以得到 trajectory 和 reward 了,速度其实是比较快的,一般只需要几秒就可以生成一条样本;
- Deterministic environment。将 LLM RLHF 套到强化学习的语境下,环境的状态转移( s t → s t + 1 s_t\rightarrow s_{t+1} st→st+1)其实就是将刚生成的 token 拼到 context 后面,是确定性的,而奖励来自 reward model (or function),也是确定性的。因此说,RLHF 的整个强化学习环境都是确定性的(当然 Agent,也就是 LLM 本身的采样生成是可以有随机性的);
- Trajectory-level reward。在 RLHF 中,奖励一般是 trajectory-level 的,即在完整回答生成后,才给出一个 reward,而中间过程,都是没有 reward 的。也正是因为这一点,作者认为 PPO 中的 value model 以及对应的 TD 残差学习,可能是不太合适的。
基于 LLM 的这三点性质,作者认为 RLHF 训练中的其实不需要 value model 即可简单高效地计算期望回报,并提出了 ReMax 方法。
从 REINFORCE 到 ReMax
作者对 RLHF 的改进没有从 PPO 出发,而是考虑了一种更早期的策略梯度类方法:REINFORCE。REINFORCE 是一种古老但经典的策略梯度算法,其形式我们在之前已经推导过,模型更新的策略梯度为:
∇ θ J ( θ ) = E τ ∼ π θ [ ∑ t = 0 T ∇ θ log π θ ( a t ∣ s t ) R ( τ ) ] \nabla_\theta J(\theta)=\mathbb{E}_{\tau\sim\pi_\theta}\left[\sum_{t=0}^T\nabla_\theta\log\pi_\theta(a_t|s_t) R(\tau)\right] \notag \\ ∇θJ(θ)=Eτ∼πθ[t=0∑T∇θlogπθ(at∣st)R(τ)]
其中 τ \tau τ 是一条轨迹(trajectory)样本,由 policy π θ \pi_\theta πθ 采样而来。
在 LLM RLHF 语境下,REINFORCE 算法的梯度可以写为:
E q ∼ D , o ∼ π θ [ ∑ t = 0 T ∇ θ log π θ ( a t ∣ q , o < t ) R ( q , o 1 : T ) ] \mathbb{E}_{q\sim D,o\sim\pi_\theta}\left[\sum_{t=0}^T\nabla_\theta\log\pi_\theta(a_t|q,o_{<t})R(q,o_{1:T})\right] \notag \\ Eq∼D,o∼πθ[t=0∑T∇θlogπθ(at∣q,o<t)R(q,o1:T)]
其中 q q q 是输入给模型的问题,采样自数据集 D D D, o t o_t ot 表示模型输出的第 t t t 个 token, R ( q , o 1 : T ) R(q,o_{1:T}) R(q,o1:T) 表示 reward model 对模型的完整回答给出的 sample level 的打分。
REINFORCE 算法直接根据一条轨迹(模型回答) 的累积回报来计算梯度,不用训练和推理额外的 value model,训练开销更小。作者认为,简单的 REINORCE 算法比较契合上面提到的 LLM 的三个特性,反而可能更加适合 RLHF。然而,REINFORCE 算法实际上也有自己的问题,那就是方差太大。
REINFORCE 算法的方差大,是早就众所周知的问题。在传统强化学习中,REINFORCE 算法中 R ( τ ) R(\tau) R(τ) 的方差来自两方面,一是 environment 的随机性,二是 policy 本身的随机性,REINFORCE 中,我们需要采样 T T T 步,直到一次 rollout 结束,才能得到最终的奖励 ,在这个过程中,来自上述两种随机性带来的采样方差不断积累,导致最终的累积奖励 R ( τ ) R(\tau) R(τ) 的方差非常大、
传统强化学习中,解决 REINFORCE 算法方差大的问题,一个比较经典的方案是在累积回报的基础上减掉一个 baseline。
E q ∼ D , o ∼ π θ [ ∑ t = 0 T ∇ θ log π θ ( a t ∣ q , o < t ) [ R ( q , o 1 : T ) − b θ ( q ) ] ] \mathbb{E}_{q\sim D,o\sim\pi_\theta}\left[\sum_{t=0}^T\nabla_\theta\log\pi_\theta(a_t|q,o_{<t})[R(q,o_{1:T})\textcolor{red}{-b_\theta(q)}]\right] \notag \\ Eq∼D,o∼πθ[t=0∑T∇θlogπθ(at∣q,o<t)[R(q,o1:T)−bθ(q)]]
ReMax 也是引入了一个 baseline 来减小 RLHF 训练中的方差。具体来说,ReMax 选用的 baseline 是 policy LLM 对于输入问题 q q q,以 greedy 的策略进行采样(实操中,取 do_sample=False
即可)得到贪婪采样结果的 reward 打分,即取:
b θ ( q ) = R ( q , o ˉ 1 : T ) , o ˉ t ∈ arg max π θ ( ⋅ ∣ q , o < t ) b_\theta(q)=R(q,\bar{o}_{1:T}),\quad \bar{o}_{t}\in\arg\max\pi_\theta(\cdot|q,o_{<t}) \notag \\ bθ(q)=R(q,oˉ1:T),oˉt∈argmaxπθ(⋅∣q,o<t)
在 REINFORCE 的基础上使用 argmax 计算 baseline,故称为 ReMax。
理论上,作者证明了 ReMax 仍是一个无偏估计,并且方差会更小。直观上来理解,ReMax 是取贪婪解码的结果,也是没有随机性的,因此方差确实会更小。原文的实验结果也展示出 ReMax 比 PPO 训练开销更小,效果也要好一些。
总结
ReMax 是很早就提出 RLHF 不用 value model 的工作之一。论文首先分析了 RLHF 相较于传统 RL 的独特性,然后在经典的 REINFORCE 算法的基础上,引入贪婪采样结果的 reward 作为 baseline,来降低方差。方法简单有效,理论分析也比较充分,是一篇很不错的工作。有一点 concern 在于,(可能是受限于资源)ReMax 的实验是做在 1.3B/7B 的 “小” 模型上的,而在上百 B 的大模型上,Simulation 还有没有那么 Fast,多一次 Simulation(用于算 baseline) 与 value model 的开销相比,是否还有那么大的优势?