DFT:从RL的视角修正SFT损失的权重
DFT:从RL的视角修正SFT损失的权重
TL; DR:作者借助重要性采样,从 RL 的视角推导 SFT 的梯度形式,比较二者的差异,并引入一个动态权重来对 SFT 损失进行修正,得到了 DFT。
从 RL 的视角来看 SFT 的梯度形式
我们首先从 RL 的视角来推导一下 SFT 的梯度形式。记 D={(x,y∗)}\mathcal{D}=\{(x,y^*)\}D={(x,y∗)} 为 SFT 数据集,其中 xxx 是 query,y∗y^*y∗ 是人类标注的 response。SFT 的训练目标可以写作下面这个个句子级别的交叉熵损失:
LSFT(θ)=E(x,y∗)∼D[−logπθ(y∗∣x)]
\mathcal{L}_{\text{SFT}}(\theta)=\mathbb{E}_{(x,y^*)\sim\mathcal{D}}[-\log\pi_\theta(y^*|x)] \notag \\
LSFT(θ)=E(x,y∗)∼D[−logπθ(y∗∣x)]
其梯度为:
∇θLSFT(θ)E(x,y∗)∼D[−∇θlogπθ(y∗∣x)](1)
\nabla_\theta\mathcal{L}_\text{SFT}(\theta)\mathbb{E}_{(x,y^*)\sim\mathcal{D}}[-\nabla_\theta\log\pi_\theta(y^*|x)] \tag{1} \\
∇θLSFT(θ)E(x,y∗)∼D[−∇θlogπθ(y∗∣x)](1)
而在强化学习中,记 yyy 是模型 πθ(⋅∣x)\pi_\theta(\cdot|x)πθ(⋅∣x) 对 query xxx 的 response,r(x,y)∈Rr(x,y)\in\mathbb{R}r(x,y)∈R 为奖励函数。强化学习的目标函数为:
J(θ)=Ex∼Dx,y∼πθ(⋅∣x)[r(x,y)]
J(\theta)=\mathbb{E}_{x\sim\mathcal{D}_x,y\sim\pi_\theta(\cdot|x)}[r(x,y)] \notag \\
J(θ)=Ex∼Dx,y∼πθ(⋅∣x)[r(x,y)]
其对应的句子级别的策略梯度为:
∇θJ(θ)=Ex∼Dx,y∼πθ(⋅∣x)[∇θlogπθ(y∣x)r(x,y)](2)
\nabla_\theta J(\theta)=\mathbb{E}_{x\sim\mathcal{D}_x,y\sim\pi_\theta(\cdot|x)}[\nabla_\theta\log\pi_\theta(y|x)r(x,y)] \tag{2} \\
∇θJ(θ)=Ex∼Dx,y∼πθ(⋅∣x)[∇θlogπθ(y∣x)r(x,y)](2)
SFT 是在一个固定的标注数据上进行训练,从 RL 的角度来看,是一种 off-policy 的算法。接下来,为了统一 SFT 和 RL 的梯度形式,我们使用重要性采样(Importance Sampling),将标注数据分布的期望转换为模型分布的期望,从而将 off-policy 的 SFT 算法(公式 1)转换为 on-policy 的算法:
∇θLSFT(θ)=Ex∼Dx,y∗∼D[−∇θlogπθ(y∗∣x)]=Ex∼Dx,y∼πθ(⋅∣x)1[y=y∗]πθ(y∣x)[−∇θlogπθ(y∣x)](3)
\nabla_\theta\mathcal{L}_\text{SFT}(\theta)=\mathbb{E}_{x\sim\mathcal{D}_x,\textcolor{red}{y^*\sim\mathcal{D}}}[-\nabla_\theta\log\pi_\theta(y^*|x)]=\mathbb{E}_{x\sim\mathcal{D}_x,\textcolor{red}{y\sim\pi_\theta(\cdot|x)}}\textcolor{red}{\frac{\mathbf{1}[y=y^*]}{\pi_\theta(y|x)}}[-\nabla_\theta\log\pi_\theta(y|x)] \tag{3} \\
∇θLSFT(θ)=Ex∼Dx,y∗∼D[−∇θlogπθ(y∗∣x)]=Ex∼Dx,y∼πθ(⋅∣x)πθ(y∣x)1[y=y∗][−∇θlogπθ(y∣x)](3)
将引入的重要性采样修正项拆开定义为两个辅助变量:
w(y∣x)=1πθ(y∣x),r(x,y)=1[y=y∗]
w(y|x)=\frac{1}{\pi_\theta(y|x)},\quad r(x,y)=\mathbf{1}[y=y^*] \notag \\
w(y∣x)=πθ(y∣x)1,r(x,y)=1[y=y∗]
从而公式 3 可以写为:
∇θLSFT(θ)=Ex∼Dx,y∼πθ(⋅∣x)w(y∣x)[−∇θlogπθ(y∣x)r(x,y)]
\nabla_\theta\mathcal{L}_\text{SFT}(\theta)=\mathbb{E}_{x\sim\mathcal{D}_x,y\sim\pi_\theta(\cdot|x)}\textcolor{blue}{w(y|x)}[-\nabla_\theta\log\pi_\theta(y|x)\textcolor{blue}{r(x,y)}] \notag \\
∇θLSFT(θ)=Ex∼Dx,y∼πθ(⋅∣x)w(y∣x)[−∇θlogπθ(y∣x)r(x,y)]
可以看到,这个形式与强化学习策略梯度的形式(公式 2)很接近了。这里的 w(y∣x)w(y|x)w(y∣x) 可以视为一个根据模型对标注 response 的生成概率来动态调整的权重,r(x,y)r(x,y)r(x,y) 可以视为一个奖励函数,在模型 response 与标注 response 完全一致时,奖励值为 1,其他奖励值全为 0。
现在,我们来思考一下这个形式的转换构造出来的这两个辅助变量的含义。
对于奖励 r(x,y)=1[y=y∗]r(x,y)=\mathbf{1}[y=y*]r(x,y)=1[y=y∗],即只有在模型的 response 与标注 response 完全相同时,奖励值为 1,其他时候奖励都为 0。显然这个函数作为奖励函数是稀疏的,而且这种稀疏性在 SFT 这种给定标注数据的训练形式下是不可避免的。这也正是 RL 的核心:exploration,所要解决的问题。在本文中,由于作者要改进的还是 SFT,因此没有对这个奖励函数作出改动。
对于权重 w(y∣x)=1/πθ(y∣x)w(y|x)=1/\pi_\theta(y|x)w(y∣x)=1/πθ(y∣x),即当模型对标注 response 给出的生成概率比较低时,www 会比较大,作者认为从 RL 的视角来看,这会导致方差很大。并且由于 SFT 形式奖励的稀疏性,这个方差大的问题会被进一步放大。稀疏的奖励会导致模型过拟合到标注 response 上。这个权重 www 是作者要改进的问题。
DFT
分析之后,作者最终将 SFT 表现不如 RL 好的原因归结在了权重 www 上。对于权重 www,解决方案就很简单了,直接乘一个 1/w1/w1/w 给它抵消掉就好了。即:
LDFT(θ)=E(x,y∗)∼D[sg(1/w)logπθ(yt∗∣x)]=E(x,y∗)∼D[sg(πθ(y∗∣x))logπθ(yt∗∣x)]
\begin{aligned}
\mathcal{L}_\text{DFT}(\theta)&=\mathbb{E}_{(x,y^*)\sim\mathcal{D}}[\textcolor{red}{\text{sg}(1/w)}\log \pi_\theta(y_t^*|x)] \\
&=\mathbb{E}_{(x,y^*)\sim\mathcal{D}}[\textcolor{red}{\text{sg}(\pi_\theta(y^*|x))}\log \pi_\theta(y_t^*|x)] \\
\end{aligned} \notag \\
LDFT(θ)=E(x,y∗)∼D[sg(1/w)logπθ(yt∗∣x)]=E(x,y∗)∼D[sg(πθ(y∗∣x))logπθ(yt∗∣x)]
另外为了避免在整个采样轨迹上计算重要性权重导致的数值不稳定性,作者最终采用了 token level 的损失形式:
LDFT(θ)=E(x,y∗)∼D[−∑t=1∣y∗∣sg(πθ(yt∗∣y<t∗,x))log(πθ(yt∗∣y<t∗,x)]
\mathcal{L}_\text{DFT}(\theta)=\mathbb{E}_{(x,y^*)\sim\mathcal{D}}\left[-\sum_{t=1}^{|y^*|}\text{sg}(\pi_\theta(y_t^*|y_{<t}^*,x))\log(\pi_\theta(y_t^*|y_{<t}^*,x)\right] \notag \\
LDFT(θ)=E(x,y∗)∼D−t=1∑∣y∗∣sg(πθ(yt∗∣y<t∗,x))log(πθ(yt∗∣y<t∗,x)
最终形式上
- 从 SFT 的视角来看,DFT 是在 SFT 的基础上乘了一个动态修正权重 πθ\pi_\thetaπθ,即在模型比较确信的样本上给较高权重。这与 focal loss 对分类不准的样本给高权重的思路正好是相反的,因为在大模型大数据时代,这些模型置信度不高的样本,很有可能是噪声数据。作者认为这反映了在大模型时代,模型的泛化能力更为重要,过拟合的问题比欠拟合更严重。
- 从 RL 的视角来看,DFT(在梯度形式上)与一般的 RL 就已经一致了,只是(但其实这一点非常重要)奖励函数非常稀疏
DFT 实现起来也非常简单,实际上就是给损失乘上了一个动态权重,代码中只需要添加一行:
loss = loss * torch.softmax(shift_logits, dim=-1).gather(1, shift_labels.unsqueeze(-1)).squeeze(-1).detach()
总结
DFT 最近讨论度很高。作者借助重要性采样,从 RL 的视角推导了 SFT 的梯度形式,发现 SFT 与 RL 在梯度形式上的差异在于一个权重 w=1/πθw=1/\pi_\thetaw=1/πθ,进而提出给 SFT 的损失函数乘上一个对其进行修正的动态权重 πθ\pi_\thetaπθ,得到了 DFT。
但是从作者推导出的形式来看,个人感觉 SFT 泛化性较弱的原因应该反而主要来自于 SFT 本身的奖励函数的稀疏性。另外还有一点,作者在推导时用到了重要性采样进行分布变换,但实际中重要性采样并不是一定稳定的,需要保证变换前后的分布差异不能太大,这一点原文中似乎也没有讨论。