知识蒸馏 Knowledge Distillation 论文 Generalized Knowledge Distillation (GKD) 目标函数的演化
知识蒸馏 Knowledge Distillation 论文 Generalized Knowledge Distillation (GKD) 目标函数的演化
flyfish
代码实践
On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes
目标函数(Objective Function) 是衡量模型预测结果与真实结果之间差异的函数,其核心作用是为模型的参数优化提供 “指导信号”—— 通过最小化(或最大化)目标函数的值,让模型逐渐学习到更优的参数,从而提升预测性能。
深度学习模型的训练本质是 “参数优化”:模型通过输入数据生成预测结果后,目标函数会计算预测值与真实标签(或期望输出)的 “差距”,得到一个量化的 “损失值”。这个损失值越大,说明模型当前的预测效果越差;反之则越好。
优化算法(如梯度下降)会基于目标函数的梯度信息,调整模型参数(如权重、偏置),最终使目标函数的值达到最小(或最大,视任务而定),此时模型的预测结果与真实结果最接近。
其他名字
损失函数(Loss Function):很多时候人们会直接说 “损失函数” 来指代目标函数,尤其在单样本或简化场景中。
代价函数(Cost Function):和损失函数类似,在不少教材或论文里,这两个词和目标函数几乎同义,只是 “代价” 更强调 “为错误付出的成本”。
优化目标(Optimization Objective):因为目标函数的核心是被模型 “优化” 的对象(比如最小化它),所以也常被称为 “优化目标”。
准则函数(Criterion Function):“准则” 就是 “判断标准”,这个词更强调它是衡量模型好坏的 “标准”,和目标函数含义一致。
这些名字都是用来指导模型优化的函数,只是在不同语境里习惯用不同的词,不用太纠结,知道它们说的是一回事儿就行
演化链
- SFT / MLE(one-hot“老师”)
→ 2) 监督 KD(固定数据 + 前向 KL,对齐软标签)
→ 3) On-Policy KD(在学生自己会走到的状态上学,消除分布失配)
→ 4) 换散度:JSD(β)/反向 KL(容量受限时更模式寻求,少幻觉)
→ 5) GKD(λ\lambdaλ 控制样本来源,DDD 自由可选,统一一切)
→ 6) RL + On-Policy GKD(奖励最大化 + 蒸馏约束,同训同收)。
0. 基础:自回归分解与“逐 token”散度
自回归 LM 满足 p(y∣x)=∏n=1Lyp(yn∣y<n,x)p(y|x)=\prod_{n=1}^{L_y}p(y_n|y_{<n},x)p(y∣x)=∏n=1Lyp(yn∣y<n,x)。因此一切“序列级”的目标,都能拆成逐 token的和。论文把“老师 pTp_TpT”与“学生 pSp_SpS”在序列 yyy 上的分布差异定义为(式 (2)):
D(pT∥pSθ)(y∣x)=1Ly∑n=1LyD(pT(⋅∣y<n,x)∥pSθ(⋅∣y<n,x)).D(p_T\!\parallel p_S^\theta)(y|x)=\frac{1}{L_y}\sum_{n=1}^{L_y} D\!\left(p_T(\cdot|y_{<n},x)\parallel p_S^\theta(\cdot|y_{<n},x)\right). D(pT∥pSθ)(y∣x)=Ly1n=1∑LyD(pT(⋅∣y<n,x)∥pSθ(⋅∣y<n,x)).
这一步把任何“散度 DDD”都落到了token 级,后续所有方法只需选择:用什么 DDD,在什么数据上评价它。
1. 监督式微调(SFT):极大似然是前向 KL 的特例
只有人工标注 (X,Y)(X,Y)(X,Y) 而没有老师时,最简单是最小化负对数似然:
LSFT(θ)=E(x,y)∼(X,Y)[−logpSθ(y∣x)].L_{\text{SFT}}(\theta)=\mathbb{E}_{(x,y)\sim(X,Y)}[-\log p_S^\theta(y|x)]. LSFT(θ)=E(x,y)∼(X,Y)[−logpSθ(y∣x)].
它等价于把每个目标 token 看成 one-hot“老师”,即在式 (2) 中令 D=DKLD= D_{\mathrm{KL}}D=DKL 且 pTp_TpT 为真分布的“δ-分布”。
梯度形态(单步 token)
对学生 softmax 的对数几率 zzz 来说,前向 KL / 交叉熵的梯度是 ∇z=pS−pT\nabla_z = p_S - p_T∇z=pS−pT。这说明 SFT/前向 KL 本质是让学生概率向目标分布对齐。
局限:只在固定数据上学,推理时学生会走到自己没见过的前缀状态,产生训练-推理分布失配(exposure bias)。
2. 监督式 KD(Supervised KD):用“软标签”的前向 KL
有了老师 pTp_TpT(可给每个 token 的全分布),就把 SFT 的 one-hot 换成老师分布,得到(式 (3)):
LSD(θ)=E(x,y)∼(X,Y)[DKL(pT∥pSθ)(y∣x)].L_{\text{SD}}(\theta)=\mathbb{E}_{(x,y)\sim(X,Y)} \Big[D_{\mathrm{KL}}\!\big(p_T\parallel p_S^\theta\big)(y|x)\Big]. LSD(θ)=E(x,y)∼(X,Y)[DKL(pT∥pSθ)(y∣x)].
好处:利用“软标签”提供的暗知识(非目标 token 的相对概率)。梯度仍是 pS−pTp_S-p_TpS−pT 的形态,但 pTp_TpT 不再是 one-hot。
局限:仍在固定序列上训练(可能是人工真值或老师生成的序列),分布失配依旧。
3. On-Policy KD:让学生在自己生成的序列上学
为解决分布失配,论文把“期望”从固定 (X,Y)(X,Y)(X,Y) 换成学生策略下的输出序列,得到(式 (4)):
LOD(θ)=Ex∼X[Ey∼pS(⋅∣x)[DKL(pT∥pSθ)(y∣x)]].L_{\text{OD}}(\theta)=\mathbb{E}_{x\sim X}\Big[\mathbb{E}_{y\sim p_S(\cdot|x)} \big[D_{\mathrm{KL}}(p_T\parallel p_S^\theta)(y|x)\big]\Big]. LOD(θ)=Ex∼X[Ey∼pS(⋅∣x)[DKL(pT∥pSθ)(y∣x)]].
关键实现细节:不对学生的采样分布反传(只在内层 KL 里更新 θ\thetaθ),可避免 REINFORCE 式高方差,训练更稳定高效。直观地说:学生先按当前策略走一遍,把“走错的地方”交给老师打分,再按 KL 梯度把这些状态上的 logits 拉回去。
与模仿学习的联系:这相当于 DAgger 风格的在线收集+专家纠正。
4. 选择更合适的“散度”DDD:从 KL 到广义 JSD
前向 KL 要求学生“覆盖老师的全部支持集”,容量不足时会把概率“摊薄”到老师几乎不选的 token 上,易造成幻觉;反向 KL 则更“择众”,只贴老师高概率 token,减少“离谱”但可能牺牲多样性。论文采用广义 JSD 在两者间连续插值(式 (1)):
DJSD(β)(P∥Q)=βDKL(P∥βP+(1−β)Q)+(1−β)DKL(Q∥βP+(1−β)Q).D_{\mathrm{JSD}(\beta)}(P\parallel Q)= \beta\, D_{\mathrm{KL}}\!\big(P\parallel \beta P+(1-\beta)Q\big)+ (1-\beta)\, D_{\mathrm{KL}}\!\big(Q\parallel \beta P+(1-\beta)Q\big). DJSD(β)(P∥Q)=βDKL(P∥βP+(1−β)Q)+(1−β)DKL(Q∥βP+(1−β)Q).
当 β→0\beta\to 0β→0 时,1βDJSD(β)→DKL(P∥Q)\tfrac{1}{\beta}D_{\mathrm{JSD}(\beta)} \to D_{\mathrm{KL}}(P\parallel Q)β1DJSD(β)→DKL(P∥Q);β→1\beta\to 1β→1 时更接近反向 KL 的行为。这样就能按任务、温度、容量调节“覆盖 vs. 模式寻求”的折衷。
经验指引:不同任务/采样温度的最优散度不同;且很多实验里,**纯 on-policy(学生样本占比 100%)**的效果最好。
5. GKD:统一“在哪些序列上学”和“用什么散度学”
把“数据来源”(固定数据 vs 学生自采样)与“散度种类”(前向/反向 KL、JSD(β)…)统一起来,得到广义 KD(式 (GKD)):
LGKD(θ)=(1−λ)E(x,y)∼(X,Y) [D(pT∥pSθ)(y∣x)]+λEx∼XEy∼pS(⋅∣x) [D(pT∥pSθ)(y∣x)].\!\!\!\!L_{\text{GKD}}(\theta) =(1-\lambda)\,\mathbb{E}_{(x,y)\sim(X,Y)}\!\!\big[D(p_T\parallel p_S^\theta)(y|x)\big] +\lambda\,\mathbb{E}_{x\sim X}\mathbb{E}_{y\sim p_S(\cdot|x)}\!\!\big[D(p_T\parallel p_S^\theta)(y|x)\big]. LGKD(θ)=(1−λ)E(x,y)∼(X,Y)[D(pT∥pSθ)(y∣x)]+λEx∼XEy∼pS(⋅∣x)[D(pT∥pSθ)(y∣x)].
- λ=0\lambda=0λ=0:退化为监督 KD;λ=1\lambda=1λ=1:纯 on-policy KD;
- DDD 可取前向/反向 KL 或 JSD(β\betaβ)。
实现上,对采样过程仍不反传。论文还给了Algorithm 1:每步按 λ\lambdaλ 抛硬币决定用哪类数据,再最小化该 batch 的散度。
一眼看懂的梯度形态(抽象)
忽略采样反传后,
∇θLGKD=E[1Ly∑n∇θD(pT(⋅)∥pSθ(⋅))⏟如前向 KL 时 ∝pS−pT],\nabla_\theta L_{\text{GKD}} = \mathbb{E}\!\left[\frac{1}{L_y}\sum_n \underbrace{\nabla_\theta D\!\big(p_T(\cdot)\parallel p_S^\theta(\cdot)\big)}_{\text{如前向 KL 时 }\;\propto\;p_S-p_T}\right], ∇θLGKD=ELy1n∑如前向 KL 时 ∝pS−pT∇θD(pT(⋅)∥pSθ(⋅)),
唯一差别在于这个期望是对哪种序列分布取(由 λ\lambdaλ 和是否 on-policy 决定),以及**DDD** 的具体形式。
6. 再进一步:把 RL 目标和 On-Policy GKD 并列优化
很多真实目标(如事实一致性)是不可导/非似然的。论文把策略梯度的奖励项与蒸馏正则项并列,给出(式 (5)):
Ex∼X[(1−α)Ey∼pSθ[r(y)]−αEy∼pS(⋅∣x)D(pT∥pSθ)(y∣x)].\mathbb{E}_{x\sim X}\Big[(1-\alpha)\,\mathbb{E}_{y\sim p_S^\theta}[r(y)] -\alpha\,\mathbb{E}_{y\sim p_S(\cdot|x)}D(p_T\parallel p_S^\theta)(y|x)\Big]. Ex∼X[(1−α)Ey∼pSθ[r(y)]−αEy∼pS(⋅∣x)D(pT∥pSθ)(y∣x)].
- 第一项:标准 REINFORCE/策略梯度的RL 目标;
- 第二项:on-policy 蒸馏正则,把策略往老师靠,以防 RL 走偏;
- α\alphaα 控制 RL 与蒸馏的权衡(α=1\alpha=1α=1 退化为仅蒸馏)。这与 RLHF 里常见的“KL 正则”相似,但这里是向老师而不是向初始策略收缩。