当前位置: 首页 > java >正文

知识蒸馏 Knowledge Distillation 论文 Generalized Knowledge Distillation (GKD) 目标函数的演化

知识蒸馏 Knowledge Distillation 论文 Generalized Knowledge Distillation (GKD) 目标函数的演化

flyfish

代码实践

On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes

目标函数(Objective Function) 是衡量模型预测结果与真实结果之间差异的函数,其核心作用是为模型的参数优化提供 “指导信号”—— 通过最小化(或最大化)目标函数的值,让模型逐渐学习到更优的参数,从而提升预测性能。

深度学习模型的训练本质是 “参数优化”:模型通过输入数据生成预测结果后,目标函数会计算预测值与真实标签(或期望输出)的 “差距”,得到一个量化的 “损失值”。这个损失值越大,说明模型当前的预测效果越差;反之则越好。
优化算法(如梯度下降)会基于目标函数的梯度信息,调整模型参数(如权重、偏置),最终使目标函数的值达到最小(或最大,视任务而定),此时模型的预测结果与真实结果最接近。

其他名字

损失函数(Loss Function):很多时候人们会直接说 “损失函数” 来指代目标函数,尤其在单样本或简化场景中。
代价函数(Cost Function):和损失函数类似,在不少教材或论文里,这两个词和目标函数几乎同义,只是 “代价” 更强调 “为错误付出的成本”。
优化目标(Optimization Objective):因为目标函数的核心是被模型 “优化” 的对象(比如最小化它),所以也常被称为 “优化目标”。
准则函数(Criterion Function):“准则” 就是 “判断标准”,这个词更强调它是衡量模型好坏的 “标准”,和目标函数含义一致。
这些名字都是用来指导模型优化的函数,只是在不同语境里习惯用不同的词,不用太纠结,知道它们说的是一回事儿就行

演化链

  1. SFT / MLE(one-hot“老师”)
    → 2) 监督 KD(固定数据 + 前向 KL,对齐软标签)
    → 3) On-Policy KD(在学生自己会走到的状态上学,消除分布失配)
    → 4) 换散度:JSD(β)/反向 KL(容量受限时更模式寻求,少幻觉)
    → 5) GKDλ\lambdaλ 控制样本来源,DDD 自由可选,统一一切)
    → 6) RL + On-Policy GKD(奖励最大化 + 蒸馏约束,同训同收)。

0. 基础:自回归分解与“逐 token”散度

自回归 LM 满足 p(y∣x)=∏n=1Lyp(yn∣y<n,x)p(y|x)=\prod_{n=1}^{L_y}p(y_n|y_{<n},x)p(yx)=n=1Lyp(yny<n,x)。因此一切“序列级”的目标,都能拆成逐 token的和。论文把“老师 pTp_TpT”与“学生 pSp_SpS”在序列 yyy 上的分布差异定义为(式 (2)):

D(pT⁣∥pSθ)(y∣x)=1Ly∑n=1LyD⁣(pT(⋅∣y<n,x)∥pSθ(⋅∣y<n,x)).D(p_T\!\parallel p_S^\theta)(y|x)=\frac{1}{L_y}\sum_{n=1}^{L_y} D\!\left(p_T(\cdot|y_{<n},x)\parallel p_S^\theta(\cdot|y_{<n},x)\right). D(pTpSθ)(yx)=Ly1n=1LyD(pT(y<n,x)pSθ(y<n,x)).

这一步把任何“散度 DDD”都落到了token 级,后续所有方法只需选择:用什么 DDD,在什么数据上评价它。

1. 监督式微调(SFT):极大似然是前向 KL 的特例

只有人工标注 (X,Y)(X,Y)(X,Y)没有老师时,最简单是最小化负对数似然:

LSFT(θ)=E(x,y)∼(X,Y)[−log⁡pSθ(y∣x)].L_{\text{SFT}}(\theta)=\mathbb{E}_{(x,y)\sim(X,Y)}[-\log p_S^\theta(y|x)]. LSFT(θ)=E(x,y)(X,Y)[logpSθ(yx)].

它等价于把每个目标 token 看成 one-hot“老师”,即在式 (2) 中令 D=DKLD= D_{\mathrm{KL}}D=DKLpTp_TpT 为真分布的“δ-分布”。

梯度形态(单步 token)
对学生 softmax 的对数几率 zzz 来说,前向 KL / 交叉熵的梯度是 ∇z=pS−pT\nabla_z = p_S - p_Tz=pSpT。这说明 SFT/前向 KL 本质是让学生概率向目标分布对齐

局限:只在固定数据上学,推理时学生会走到自己没见过的前缀状态,产生训练-推理分布失配(exposure bias)。

2. 监督式 KD(Supervised KD):用“软标签”的前向 KL

有了老师 pTp_TpT(可给每个 token 的全分布),就把 SFT 的 one-hot 换成老师分布,得到(式 (3)):

LSD(θ)=E(x,y)∼(X,Y)[DKL⁣(pT∥pSθ)(y∣x)].L_{\text{SD}}(\theta)=\mathbb{E}_{(x,y)\sim(X,Y)} \Big[D_{\mathrm{KL}}\!\big(p_T\parallel p_S^\theta\big)(y|x)\Big]. LSD(θ)=E(x,y)(X,Y)[DKL(pTpSθ)(yx)].

好处:利用“软标签”提供的暗知识(非目标 token 的相对概率)。梯度仍是 pS−pTp_S-p_TpSpT 的形态,但 pTp_TpT 不再是 one-hot。

局限:仍在固定序列上训练(可能是人工真值或老师生成的序列),分布失配依旧。

3. On-Policy KD:让学生在自己生成的序列上学

为解决分布失配,论文把“期望”从固定 (X,Y)(X,Y)(X,Y) 换成学生策略下的输出序列,得到(式 (4)):

LOD(θ)=Ex∼X[Ey∼pS(⋅∣x)[DKL(pT∥pSθ)(y∣x)]].L_{\text{OD}}(\theta)=\mathbb{E}_{x\sim X}\Big[\mathbb{E}_{y\sim p_S(\cdot|x)} \big[D_{\mathrm{KL}}(p_T\parallel p_S^\theta)(y|x)\big]\Big]. LOD(θ)=ExX[EypS(x)[DKL(pTpSθ)(yx)]].

关键实现细节不对学生的采样分布反传(只在内层 KL 里更新 θ\thetaθ),可避免 REINFORCE 式高方差,训练更稳定高效。直观地说:学生先按当前策略走一遍,把“走错的地方”交给老师打分,再按 KL 梯度把这些状态上的 logits 拉回去。

与模仿学习的联系:这相当于 DAgger 风格的在线收集+专家纠正

4. 选择更合适的“散度”DDD:从 KL 到广义 JSD

前向 KL 要求学生“覆盖老师的全部支持集”,容量不足时会把概率“摊薄”到老师几乎不选的 token 上,易造成幻觉;反向 KL 则更“择众”,只贴老师高概率 token,减少“离谱”但可能牺牲多样性。论文采用广义 JSD 在两者间连续插值(式 (1)):

DJSD(β)(P∥Q)=βDKL⁣(P∥βP+(1−β)Q)+(1−β)DKL⁣(Q∥βP+(1−β)Q).D_{\mathrm{JSD}(\beta)}(P\parallel Q)= \beta\, D_{\mathrm{KL}}\!\big(P\parallel \beta P+(1-\beta)Q\big)+ (1-\beta)\, D_{\mathrm{KL}}\!\big(Q\parallel \beta P+(1-\beta)Q\big). DJSD(β)(PQ)=βDKL(PβP+(1β)Q)+(1β)DKL(QβP+(1β)Q).

β→0\beta\to 0β0 时,1βDJSD(β)→DKL(P∥Q)\tfrac{1}{\beta}D_{\mathrm{JSD}(\beta)} \to D_{\mathrm{KL}}(P\parallel Q)β1DJSD(β)DKL(PQ)β→1\beta\to 1β1 时更接近反向 KL 的行为。这样就能按任务、温度、容量调节“覆盖 vs. 模式寻求”的折衷。

经验指引:不同任务/采样温度的最优散度不同;且很多实验里,**纯 on-policy(学生样本占比 100%)**的效果最好。

5. GKD:统一“在哪些序列上学”和“用什么散度学”

把“数据来源”(固定数据 vs 学生自采样)与“散度种类”(前向/反向 KL、JSD(β)…)统一起来,得到广义 KD(式 (GKD)):

⁣ ⁣ ⁣ ⁣LGKD(θ)=(1−λ)E(x,y)∼(X,Y)⁣ ⁣[D(pT∥pSθ)(y∣x)]+λEx∼XEy∼pS(⋅∣x)⁣ ⁣[D(pT∥pSθ)(y∣x)].\!\!\!\!L_{\text{GKD}}(\theta) =(1-\lambda)\,\mathbb{E}_{(x,y)\sim(X,Y)}\!\!\big[D(p_T\parallel p_S^\theta)(y|x)\big] +\lambda\,\mathbb{E}_{x\sim X}\mathbb{E}_{y\sim p_S(\cdot|x)}\!\!\big[D(p_T\parallel p_S^\theta)(y|x)\big]. LGKD(θ)=(1λ)E(x,y)(X,Y)[D(pTpSθ)(yx)]+λExXEypS(x)[D(pTpSθ)(yx)].

  • λ=0\lambda=0λ=0:退化为监督 KDλ=1\lambda=1λ=1纯 on-policy KD
  • DDD 可取前向/反向 KL 或 JSD(β\betaβ)
    实现上,对采样过程仍不反传。论文还给了Algorithm 1:每步按 λ\lambdaλ 抛硬币决定用哪类数据,再最小化该 batch 的散度。

一眼看懂的梯度形态(抽象)
忽略采样反传后,

∇θLGKD=E⁣[1Ly∑n∇θD⁣(pT(⋅)∥pSθ(⋅))⏟如前向 KL 时 ∝pS−pT],\nabla_\theta L_{\text{GKD}} = \mathbb{E}\!\left[\frac{1}{L_y}\sum_n \underbrace{\nabla_\theta D\!\big(p_T(\cdot)\parallel p_S^\theta(\cdot)\big)}_{\text{如前向 KL 时 }\;\propto\;p_S-p_T}\right], θLGKD=ELy1n如前向 KL  pSpTθD(pT()pSθ()),

唯一差别在于这个期望是对哪种序列分布取(由 λ\lambdaλ 和是否 on-policy 决定),以及**DDD** 的具体形式。

6. 再进一步:把 RL 目标和 On-Policy GKD 并列优化

很多真实目标(如事实一致性)是不可导/非似然的。论文把策略梯度的奖励项蒸馏正则项并列,给出(式 (5)):

Ex∼X[(1−α)Ey∼pSθ[r(y)]−αEy∼pS(⋅∣x)D(pT∥pSθ)(y∣x)].\mathbb{E}_{x\sim X}\Big[(1-\alpha)\,\mathbb{E}_{y\sim p_S^\theta}[r(y)] -\alpha\,\mathbb{E}_{y\sim p_S(\cdot|x)}D(p_T\parallel p_S^\theta)(y|x)\Big]. ExX[(1α)EypSθ[r(y)]αEypS(x)D(pTpSθ)(yx)].

  • 第一项:标准 REINFORCE/策略梯度的RL 目标
  • 第二项:on-policy 蒸馏正则,把策略往老师靠,以防 RL 走偏;
  • α\alphaα 控制 RL 与蒸馏的权衡α=1\alpha=1α=1 退化为仅蒸馏)。这与 RLHF 里常见的“KL 正则”相似,但这里是向老师而不是向初始策略收缩。
http://www.xdnf.cn/news/18510.html

相关文章:

  • 【Cmake】Cmake概览
  • 使用GMail API 发送邮箱
  • OpenSCA开源社区每日安全漏洞及投毒情报资讯|21th Aug. , 2025
  • 前端github-workflows部署腾讯云轻量服务器
  • 实用R语言机器学习指南:从数据预处理到模型实战(附配套学习资源)
  • docker 查看容器 docker 筛选容器
  • 循环神经网络实战:GRU 对比 LSTM 的中文情感分析(三)
  • Flask数据库迁移实战指南
  • LeetCode100-76最小覆盖子串
  • 数据库备份sql文件过大,phpAdmin无法执行Sql
  • Python递归下降解析器深度解析:从原理到工程实践
  • 异常值检测:孤立森林模型(IsolationForest)总结
  • Flowise 任意文件上传漏洞 含Flowise Docker安装、漏洞复现(CVE-2025-26319)
  • 如何使用 DeepSeek 助力工作:全面指南​
  • AWS OpenSearch 是什么
  • ROS2下YOLO+Moveit+PCL机械臂自主避障抓取方案
  • 如何理解AP服务发现协议中“如果某项服务需要被配置为可通过多个不同的网络接口进行访问,则应为每个网络接口使用一个独立的客户端服务实例”?
  • Unreal Engine APawn 与 ACharacter 比较
  • 停车场道闸的常见形式
  • Docker的安装
  • 什么是数据分类分级?数据分类分级技术实现路径及产品推荐
  • 逆向代码笔记
  • centos7安装oracle19c流程(自用)
  • 全面解析 `strchr` 字符串查找函数
  • 闲置笔记本链接硬盘盒充当Windows NAS 网易UU远程助力数据读取和处理
  • vivo招AI架构专家(AI Agent方向)
  • 云原生(Cloud Native)技术概述
  • 密码管理中硬编码密码
  • react的基本使用
  • 【学习记录】structuredClone,URLSearchParams,groupBy