当前位置: 首页 > news >正文

KL Loss

背景

KL Loss主要监督的是模型输出分布 VS 目标分布 之间的相似性
它不直接监督位置、速度等数值,而是监督模型「认为哪种可能性更大」是否和目标一致。
在多模态预测、知识蒸馏、策略学习中尤为重要。

KL 散度主要监督什么?

项目监督内容应用场景
分布相似性模型输出的概率分布(预测) vs 目标分布(通常是软标签)知识蒸馏、轨迹分布、行为克隆等
不确定性建模模型输出多个选择的分布(如多轨迹) vs 真值分布(soft target)轨迹预测、多模态输出
知识对齐学生网络预测分布 vs 教师网络的 soft 分布蒸馏
行为模仿/规划策略模型生成的动作分布 vs 专家动作分布模仿学习、策略学习

具体例子

  1. 知识蒸馏(Knowledge Distillation)

监督:


KL(Teacher(logits).softmax || Student(logits).softmax)

目标:让学生网络模仿教师网络输出的“概率分布”,而不是 hard label。

  1. 轨迹预测(Trajectory Prediction)
如果模型预测多种未来轨迹,每种轨迹有一个概率(例如多模态轨迹):predicted_probs = [0.6, 0.3, 0.1]
ground_truth_probs = [1.0, 0.0, 0.0]  # one-hot or soft label from expertKL(predicted || ground_truth)
  1. 行为克隆(Behavior Cloning)/模仿学习
    如果从专家(如人类或 rule-based agent)采样得到 soft policy 分布,模型输出 policy logits:
expert_policy = [0.7, 0.2, 0.1]
model_output = logits → softmax → [0.4, 0.4, 0.2]loss = KL(expert_policy || model_output)

目标:让模型模仿专家的策略分布(而不是只学最优动作)。

最基础的手写 KL 散度 loss (batch-wise)

假设:

p_target 是目标分布(通常来自 ground truth,已经是 soft label,如 one-hot 或 softmax)

q_pred 是模型输出分布(经过 softmax 或 log_softmax 之后)

import torch
import torch.nn.functional as Fdef kl_loss_manual(log_q, p):"""手动实现的KL散度:KL(p || q)参数:- log_q: 模型输出的对数概率分布(log_softmax后的)- p: 目标分布(soft label 或 one-hot)返回:- 平均 KL 散度 loss"""kl = p * (torch.log(p + 1e-10) - log_q)  # 避免 log(0)return kl.sum(dim=-1).mean()
# 模拟一个 batch,有3个样本,每个是3类分类任务
logits = torch.tensor([[2.0, 1.0, 0.1],[1.5, 2.0, 0.5],[0.1, 0.2, 3.0]])# 模型输出的 log_softmax
log_q = F.log_softmax(logits, dim=1)# 假设目标是 one-hot(可以是 soft label)
p = torch.tensor([[1.0, 0.0, 0.0],[0.0, 1.0, 0.0],[0.0, 0.0, 1.0]])loss = kl_loss_manual(log_q, p)
print("KL Loss:", loss.item())
http://www.xdnf.cn/news/1421497.html

相关文章:

  • Python OpenCV图像处理与深度学习:Python OpenCV图像滤波入门
  • [系统架构设计师]论文(二十三)
  • 基于SpringBoot+MYSQL开发的师生成果管理系统
  • 美术馆预约小程序|基于微信小程序的美术馆预约平台设计与实现(源码+数据库+文档)
  • zotero.sqlite已损坏
  • 第9篇:监控与运维 - 集成Actuator健康检查
  • 『C++成长记』vector模拟实现
  • 车载总线架构 --- 车载LIN总线传输层概述
  • 百胜软件获邀出席第七届中国智慧零售大会,智能中台助力品牌零售数智变革
  • C++ 虚继承:破解菱形继承的“双亲困境”
  • 【macOS】垃圾箱中文件无法清理的--特殊方法
  • Linux | 走进网络世界:MAC、IP 与通信的那些事
  • PyTorch 实战(3)—— PyTorch vs. TensorFlow:深度学习框架的王者之争
  • mysql中如何解析某个字段是否是中文
  • 攻防演练笔记
  • Frida Hook API 转换/显示堆栈
  • 【数学建模学习笔记】缺失值处理
  • 数学分析原理答案——第七章 习题13
  • 文件夹上传 (UploadFolder)
  • crypto-babyrsa(2025YC行业赛)
  • 【系统架构师设计(8)】需求分析之 SysML系统建模语言:从软件工程到系统工程的跨越
  • 【机器学习学习笔记】numpy基础2
  • 基于 HTML、CSS 和 JavaScript 的智能图像边缘检测系统
  • ESB 走向黄昏,为什么未来属于 iPaaS?
  • 【第十一章】Python 队列全方位解析:从基础到实战
  • 计算机网络技术(四)完结
  • 9月1日
  • 8Lane V-by-One HS LVDS FMC Card
  • 【STM32】贪吃蛇 [阶段 8] 嵌入式游戏引擎通用框架设计
  • IO进程线程;标准io;文件IO;0901