当前位置: 首页 > news >正文

知识蒸馏 - 基于KL散度的知识蒸馏 HelloWorld 示例 采用PyTorch 内置函数F.kl_div的实现方式

知识蒸馏 - 基于KL散度的知识蒸馏 HelloWorld 示例 采用PyTorch 内置函数F.kl_div的实现方式

flyfish

kl_div 是 Kullback-Leibler Divergence的英文缩写。
其中,KL 对应提出该概念的两位学者(Kullback 和 Leibler)的姓氏首字母“div”是 divergence(散度)的缩写。

F.kl_div(logQ, P, reduction='sum') 等价于 torch.sum(P * (torch.log(P) - logQ))

import torch
import torch.nn.functional as F# 1. 定义示例输入(教师和学生的logits)
teacher_logits = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32)
student_logits = torch.tensor([[1.2, 2.1, 2.9], [3.8, 5.2, 6.1]], dtype=torch.float32)
T = 2.0  # 温度参数
batch_size = teacher_logits.size(0)# 2. 温度软化处理
teacher_scaled = teacher_logits / T
student_scaled = student_logits / T# 3. 计算分布
teacher_soft = F.softmax(teacher_scaled, dim=-1)  # 教师分布 P
student_log_soft = F.log_softmax(student_scaled, dim=-1)  # 学生对数分布 log Q# 4. 两种方式计算KL散度
# 方式1:手动计算(原始公式)
manual_kl = torch.sum(teacher_soft * (torch.log(teacher_soft) - student_log_soft)) / batch_size
manual_kl *= T**2  # 温度补偿# 方式2:使用PyTorch自带的F.kl_div
# 注意:F.kl_div(input=logQ, target=P, reduction='sum') 对应 sum(P*(logP - logQ))
torch_kl = F.kl_div(student_log_soft, teacher_soft, reduction='sum') / batch_size
torch_kl *= T**2  # 温度补偿# 5. 结果对比
print("===== 教师分布 P (softmax后) =====")
print(teacher_soft)
print("\n===== 学生对数分布 logQ (log_softmax后) =====")
print(student_log_soft)
print("\n===== KL散度计算结果 =====")
print(f"手动计算: {manual_kl.item():.6f}")
print(f"F.kl_div计算: {torch_kl.item():.6f}")
print(f"两者是否等价 (误差<1e-6): {torch.allclose(manual_kl, torch_kl, atol=1e-6)}")
===== 教师分布 P (softmax后) =====
tensor([[0.1863, 0.3072, 0.5065],[0.1863, 0.3072, 0.5065]])===== 学生对数分布 logQ (log_softmax后) =====
tensor([[-1.5909, -1.1409, -0.7409],[-1.8200, -1.1200, -0.6700]])===== KL散度计算结果 =====
手动计算: 0.008507
F.kl_div计算: 0.008507
两者是否等价 (误差<1e-6): True

说明:

1.输入设置:构造了教师和学生模型的logits(模拟不同的预测结果),并设置温度参数T=2.0

2.分布计算:
教师分布teacher_soft:通过softmax得到概率分布 PPP
学生对数分布student_log_soft:通过log_softmax得到 log⁡Q\log QlogQ

3.两种KL计算方式:

手动计算:严格按照公式 KL(P∥Q)=∑P⋅(log⁡P−log⁡Q)\text{KL}(P \parallel Q) = \sum P \cdot (\log P - \log Q)KL(PQ)=P(logPlogQ) 实现,除以批次大小后乘以温度平方补偿。

F.kl_div计算:直接调用PyTorch函数,注意参数顺序为(logQ, P),使用reduction='sum'确保与手动计算的求和逻辑一致。

4.等价性验证:通过torch.allclose检查两者结果是否在允许的浮点数误差范围内(1e-6)一致。

http://www.xdnf.cn/news/1235413.html

相关文章:

  • 标记-清除算法中的可达性判定与Chrome DevTools内存分析实践
  • Rust: 获取 MAC 地址方法大全
  • webrtv弱网-QualityScalerResource 源码分析及算法原理
  • 集成电路学习:什么是USB HID人机接口设备
  • Hertzbeat如何配置redis?保存在redis的数据是可读数据
  • PostgreSQL面试题及详细答案120道(21-40)
  • 腾讯人脸识别
  • 14.Redis 哨兵 Sentinel
  • C++中多线程和互斥锁的基本使用
  • [硬件电路-148]:数字电路 - 什么是CMOS电平、TTL电平?还有哪些其他电平标准?发展历史?
  • 本地环境vue与springboot联调
  • 2025年6月电子学会青少年软件编程(C语言)等级考试试卷(四级)
  • [硬件电路-143]:模拟电路 - 开关电源与线性稳压电源的详细比较
  • Ubuntu22.4部署大模型前置安装
  • webrtc弱网-QualityScaler 源码分析与算法原理
  • ubuntu apt安装与dpkg安装相互之间的关系
  • (一)全栈(react配置/https支持/useState多组件传递/表单提交/React Query/axois封装/Router)
  • 自动驾驶中的传感器技术18——Camera(9)
  • GitLab 代码管理平台部署及使用
  • Java基本技术讲解
  • PPT自动化 python-pptx - 9: 图表(chart)
  • 决策树学习全解析:从理论到实战
  • 【LeetCode刷题指南】--二叉树的后序遍历,二叉树遍历
  • PPT写作五个境界--仅供学习交流使用
  • 【1】WPF界面开发入门—— 图书馆程序:登录界面设计
  • 业务系统跳转Nacos免登录方案实践
  • web前端React和Vue框架与库安全实践
  • 【设计模式】4.装饰器模式
  • ThinkPHP5x,struts2等框架靶场复现
  • LLM - 智能体工作流设计模式