当前位置: 首页 > news >正文

DFT:从RL的视角修正SFT损失的权重

DFT:从RL的视角修正SFT损失的权重

TL; DR:作者借助重要性采样,从 RL 的视角推导 SFT 的梯度形式,比较二者的差异,并引入一个动态权重来对 SFT 损失进行修正,得到了 DFT。

从 RL 的视角来看 SFT 的梯度形式

我们首先从 RL 的视角来推导一下 SFT 的梯度形式。记 D={(x,y∗)}\mathcal{D}=\{(x,y^*)\}D={(x,y)} 为 SFT 数据集,其中 xxx 是 query,y∗y^*y 是人类标注的 response。SFT 的训练目标可以写作下面这个个句子级别的交叉熵损失:
LSFT(θ)=E(x,y∗)∼D[−log⁡πθ(y∗∣x)] \mathcal{L}_{\text{SFT}}(\theta)=\mathbb{E}_{(x,y^*)\sim\mathcal{D}}[-\log\pi_\theta(y^*|x)] \notag \\ LSFT(θ)=E(x,y)D[logπθ(yx)]
其梯度为:
∇θLSFT(θ)E(x,y∗)∼D[−∇θlog⁡πθ(y∗∣x)](1) \nabla_\theta\mathcal{L}_\text{SFT}(\theta)\mathbb{E}_{(x,y^*)\sim\mathcal{D}}[-\nabla_\theta\log\pi_\theta(y^*|x)] \tag{1} \\ θLSFT(θ)E(x,y)D[θlogπθ(yx)](1)
而在强化学习中,记 yyy 是模型 πθ(⋅∣x)\pi_\theta(\cdot|x)πθ(x) 对 query xxx 的 response,r(x,y)∈Rr(x,y)\in\mathbb{R}r(x,y)R 为奖励函数。强化学习的目标函数为:
J(θ)=Ex∼Dx,y∼πθ(⋅∣x)[r(x,y)] J(\theta)=\mathbb{E}_{x\sim\mathcal{D}_x,y\sim\pi_\theta(\cdot|x)}[r(x,y)] \notag \\ J(θ)=ExDx,yπθ(x)[r(x,y)]
其对应的句子级别的策略梯度为:
∇θJ(θ)=Ex∼Dx,y∼πθ(⋅∣x)[∇θlog⁡πθ(y∣x)r(x,y)](2) \nabla_\theta J(\theta)=\mathbb{E}_{x\sim\mathcal{D}_x,y\sim\pi_\theta(\cdot|x)}[\nabla_\theta\log\pi_\theta(y|x)r(x,y)] \tag{2} \\ θJ(θ)=ExDx,yπθ(x)[θlogπθ(yx)r(x,y)](2)
SFT 是在一个固定的标注数据上进行训练,从 RL 的角度来看,是一种 off-policy 的算法。接下来,为了统一 SFT 和 RL 的梯度形式,我们使用重要性采样(Importance Sampling),将标注数据分布的期望转换为模型分布的期望,从而将 off-policy 的 SFT 算法(公式 1)转换为 on-policy 的算法:
∇θLSFT(θ)=Ex∼Dx,y∗∼D[−∇θlog⁡πθ(y∗∣x)]=Ex∼Dx,y∼πθ(⋅∣x)1[y=y∗]πθ(y∣x)[−∇θlog⁡πθ(y∣x)](3) \nabla_\theta\mathcal{L}_\text{SFT}(\theta)=\mathbb{E}_{x\sim\mathcal{D}_x,\textcolor{red}{y^*\sim\mathcal{D}}}[-\nabla_\theta\log\pi_\theta(y^*|x)]=\mathbb{E}_{x\sim\mathcal{D}_x,\textcolor{red}{y\sim\pi_\theta(\cdot|x)}}\textcolor{red}{\frac{\mathbf{1}[y=y^*]}{\pi_\theta(y|x)}}[-\nabla_\theta\log\pi_\theta(y|x)] \tag{3} \\ θLSFT(θ)=ExDx,yD[θlogπθ(yx)]=ExDx,yπθ(x)πθ(yx)1[y=y][θlogπθ(yx)](3)
将引入的重要性采样修正项拆开定义为两个辅助变量:
w(y∣x)=1πθ(y∣x),r(x,y)=1[y=y∗] w(y|x)=\frac{1}{\pi_\theta(y|x)},\quad r(x,y)=\mathbf{1}[y=y^*] \notag \\ w(yx)=πθ(yx)1,r(x,y)=1[y=y]
从而公式 3 可以写为:
∇θLSFT(θ)=Ex∼Dx,y∼πθ(⋅∣x)w(y∣x)[−∇θlog⁡πθ(y∣x)r(x,y)] \nabla_\theta\mathcal{L}_\text{SFT}(\theta)=\mathbb{E}_{x\sim\mathcal{D}_x,y\sim\pi_\theta(\cdot|x)}\textcolor{blue}{w(y|x)}[-\nabla_\theta\log\pi_\theta(y|x)\textcolor{blue}{r(x,y)}] \notag \\ θLSFT(θ)=ExDx,yπθ(x)w(yx)[θlogπθ(yx)r(x,y)]
可以看到,这个形式与强化学习策略梯度的形式(公式 2)很接近了。这里的 w(y∣x)w(y|x)w(yx) 可以视为一个根据模型对标注 response 的生成概率来动态调整的权重,r(x,y)r(x,y)r(x,y) 可以视为一个奖励函数,在模型 response 与标注 response 完全一致时,奖励值为 1,其他奖励值全为 0。

现在,我们来思考一下这个形式的转换构造出来的这两个辅助变量的含义。

对于奖励 r(x,y)=1[y=y∗]r(x,y)=\mathbf{1}[y=y*]r(x,y)=1[y=y],即只有在模型的 response 与标注 response 完全相同时,奖励值为 1,其他时候奖励都为 0。显然这个函数作为奖励函数是稀疏的,而且这种稀疏性在 SFT 这种给定标注数据的训练形式下是不可避免的。这也正是 RL 的核心:exploration,所要解决的问题。在本文中,由于作者要改进的还是 SFT,因此没有对这个奖励函数作出改动。

对于权重 w(y∣x)=1/πθ(y∣x)w(y|x)=1/\pi_\theta(y|x)w(yx)=1/πθ(yx),即当模型对标注 response 给出的生成概率比较低时,www 会比较大,作者认为从 RL 的视角来看,这会导致方差很大。并且由于 SFT 形式奖励的稀疏性,这个方差大的问题会被进一步放大。稀疏的奖励会导致模型过拟合到标注 response 上。这个权重 www 是作者要改进的问题。

DFT

分析之后,作者最终将 SFT 表现不如 RL 好的原因归结在了权重 www 上。对于权重 www,解决方案就很简单了,直接乘一个 1/w1/w1/w 给它抵消掉就好了。即:
LDFT(θ)=E(x,y∗)∼D[sg(1/w)log⁡πθ(yt∗∣x)]=E(x,y∗)∼D[sg(πθ(y∗∣x))log⁡πθ(yt∗∣x)] \begin{aligned} \mathcal{L}_\text{DFT}(\theta)&=\mathbb{E}_{(x,y^*)\sim\mathcal{D}}[\textcolor{red}{\text{sg}(1/w)}\log \pi_\theta(y_t^*|x)] \\ &=\mathbb{E}_{(x,y^*)\sim\mathcal{D}}[\textcolor{red}{\text{sg}(\pi_\theta(y^*|x))}\log \pi_\theta(y_t^*|x)] \\ \end{aligned} \notag \\ LDFT(θ)=E(x,y)D[sg(1/w)logπθ(ytx)]=E(x,y)D[sg(πθ(yx))logπθ(ytx)]
另外为了避免在整个采样轨迹上计算重要性权重导致的数值不稳定性,作者最终采用了 token level 的损失形式:
LDFT(θ)=E(x,y∗)∼D[−∑t=1∣y∗∣sg(πθ(yt∗∣y<t∗,x))log⁡(πθ(yt∗∣y<t∗,x)] \mathcal{L}_\text{DFT}(\theta)=\mathbb{E}_{(x,y^*)\sim\mathcal{D}}\left[-\sum_{t=1}^{|y^*|}\text{sg}(\pi_\theta(y_t^*|y_{<t}^*,x))\log(\pi_\theta(y_t^*|y_{<t}^*,x)\right] \notag \\ LDFT(θ)=E(x,y)Dt=1ysg(πθ(yty<t,x))log(πθ(yty<t,x)
最终形式上

  • 从 SFT 的视角来看,DFT 是在 SFT 的基础上乘了一个动态修正权重 πθ\pi_\thetaπθ,即在模型比较确信的样本上给较高权重。这与 focal loss 对分类不准的样本给高权重的思路正好是相反的,因为在大模型大数据时代,这些模型置信度不高的样本,很有可能是噪声数据。作者认为这反映了在大模型时代,模型的泛化能力更为重要,过拟合的问题比欠拟合更严重。
  • 从 RL 的视角来看,DFT(在梯度形式上)与一般的 RL 就已经一致了,只是(但其实这一点非常重要)奖励函数非常稀疏

DFT 实现起来也非常简单,实际上就是给损失乘上了一个动态权重,代码中只需要添加一行:

loss = loss * torch.softmax(shift_logits, dim=-1).gather(1, shift_labels.unsqueeze(-1)).squeeze(-1).detach()

总结

DFT 最近讨论度很高。作者借助重要性采样,从 RL 的视角推导了 SFT 的梯度形式,发现 SFT 与 RL 在梯度形式上的差异在于一个权重 w=1/πθw=1/\pi_\thetaw=1/πθ,进而提出给 SFT 的损失函数乘上一个对其进行修正的动态权重 πθ\pi_\thetaπθ,得到了 DFT。

但是从作者推导出的形式来看,个人感觉 SFT 泛化性较弱的原因应该反而主要来自于 SFT 本身的奖励函数的稀疏性。另外还有一点,作者在推导时用到了重要性采样进行分布变换,但实际中重要性采样并不是一定稳定的,需要保证变换前后的分布差异不能太大,这一点原文中似乎也没有讨论。

http://www.xdnf.cn/news/1451701.html

相关文章:

  • 【高分论文密码】大尺度空间模拟预测与数字制图
  • Django事务
  • Leetcode 240. 搜索二维矩阵 II 矩阵 / 二分
  • 垃圾回收,几种GC算法及GC机制
  • 数据库中事务、指令、写法解读
  • 搭建基于 Solon AI 的 Streamable MCP 服务并部署至阿里云百炼
  • 【多线程初阶】线程安全问题 死锁产生 何如避免死锁
  • 前端vue常见标签属性及作用解析
  • 零售消费企业的数字化增长实践,2025新版下载
  • 在 Debian 系统上清理缓存的方式和具体操作方法
  • Grafana - 监控磁盘使用率Variables使用
  • 卫星互联网安全风险及关键技术探索
  • 【深度学习】P1 引言(待完成)
  • Conda 常用命令大全
  • Axure RP 9 Mac 交互原型设计
  • iPhone17再爆猛料?苹果2025秋季发布会亮点抢先看
  • Jenkins调用ansible部署lnmp平台
  • 阿里云-基于通义灵码实现高效 AI 编码 | 1 | 在 Visual Studio Code 中安装和使用灵码
  • Redis vs Memcached vs MongoDB:深入对比与选型指南
  • AE(自动编码器)技术解析
  • Photoshop - Photoshop 触摸功能
  • 2025高教社杯国赛数学建模选题建议+初步分析
  • Java Web :技术根基与产业实践的多维耦合
  • CSS 渐变边框
  • tensorflow常用使用场景
  • 开源免费工具,使用 Copicseal 批量添加照片参数水印教程
  • 打造大师级渲染:10个高效工作流技巧,质效双升
  • VisionPro工业相机 硬触发操作前以及Vs实现
  • iOS 抓包工具怎么选?开发者的实战经验与选择指南
  • WEB3的资料——免费开放