当前位置: 首页 > backend >正文

【系列07】端侧AI:构建与部署高效的本地化AI模型 第6章:知识蒸馏(Knowledge Distillation

第6章:知识蒸馏(Knowledge Distillation)

在构建端侧AI模型时,我们常常面临一个两难的局面:一方面需要大模型的强大性能,另一方面又必须满足端侧设备对模型体积和计算效率的要求。知识蒸馏是一种优雅的解决方案,它允许我们用一个大型的、性能优越的“教师模型”来指导一个小型、高效的“学生模型”的学习,从而让学生模型在保持轻量化的同时,获得接近教师模型的性能。


什么是知识蒸馏?

知识蒸馏的核心思想是转移知识。它不是简单地让学生模型去学习标注好的“硬标签”(hard labels),而是让它去学习教师模型的“软标签”(soft labels)。

  • 硬标签:指数据集中明确的类别标签,例如一张图片是“猫”或“狗”。学生模型的目标是尽可能地预测出正确的硬标签。
  • 软标签:指教师模型对每个类别的预测概率分布。例如,教师模型不仅会预测图片是“猫”,还会给出“狗”的概率是0.05,“老虎”的概率是0.02。这个概率分布包含了比单一硬标签更丰富的知识,因为它体现了不同类别之间的相似性和关系。

知识蒸馏的原理就是通过损失函数让学生模型的预测概率分布尽可能地接近教师模型的预测概率分布。


如何用大模型(教师模型)指导小模型(学生模型)的学习

知识蒸馏的训练过程可以概括为以下步骤:

  1. 选择教师模型:首先,你需要一个已经训练好的、性能强大的模型,作为你的教师。这个模型通常非常大,不适合直接部署。
  2. 选择学生模型:然后,你需要一个更小、更简单的模型,它将作为你的学生。这个模型需要有足够的容量来学习教师的知识。
  3. 构建训练流程:在训练阶段,你需要同时运行教师模型和学生模型。
    • 将同一批数据输入给教师模型,得到其预测的软标签(概率分布)。
    • 将同一批数据输入给学生模型,得到其预测的概率分布
  4. 计算损失函数:知识蒸馏的损失函数通常由两部分组成:
    • 蒸馏损失(Distillation Loss):用于衡量学生模型的概率分布与教师模型的软标签之间的差异。通常使用KL散度(Kullback-Leibler divergence)来计算。
    • 学生损失(Student Loss):用于衡量学生模型与真实硬标签之间的差异。通常使用交叉熵损失。
  5. 联合优化:通过联合优化这两个损失函数,学生模型不仅学习了硬标签,还从教师模型那里“继承”了更深层次的模式和知识。

实践:构建一个学生网络,并用一个预训练好的教师模型进行蒸馏

下面是一个使用PyTorch进行知识蒸馏的简化代码示例。我们将使用一个预训练的ResNet18作为教师,并构建一个更简单的网络作为学生。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import models# 1. 定义教师模型 (使用预训练的ResNet18)
teacher_model = models.resnet18(pretrained=True)
teacher_model.eval() # 确保教师模型处于评估模式# 2. 定义学生模型 (一个简单的全连接网络)
class StudentNet(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(1000, 500)self.relu = nn.ReLU()self.fc2 = nn.Linear(500, 100)def forward(self, x):x = self.relu(self.fc1(x))x = self.fc2(x)return xstudent_model = StudentNet()# 3. 定义损失函数
# 这里我们用两个损失函数,一个用于蒸馏,一个用于学生自己的学习
distillation_loss = nn.KLDivLoss(reduction="batchmean")
student_loss = nn.CrossEntropyLoss()# 4. 定义优化器
optimizer = optim.Adam(student_model.parameters(), lr=0.001)# 5. 训练循环 (简化版)
# 假设我们有一个dataloder
# for inputs, labels in dataloader:
#     # 将数据输入教师模型
#     with torch.no_grad():
#         teacher_outputs = teacher_model(inputs)#     # 将数据输入学生模型
#     student_outputs = student_model(inputs)#     # 计算损失
#     # 温度T是一个超参数,用于平滑概率分布
#     T = 2.0 
#     loss_distillation = distillation_loss(
#         F.log_softmax(student_outputs / T, dim=1),
#         F.softmax(teacher_outputs / T, dim=1)
#     )#     # 硬标签损失
#     loss_student = student_loss(student_outputs, labels)#     # 联合损失,通常会给两个损失分配权重
#     alpha = 0.5
#     total_loss = alpha * loss_distillation + (1 - alpha) * loss_student#     # 反向传播和优化
#     optimizer.zero_grad()
#     total_loss.backward()
#     optimizer.step()

通过这样的训练流程,学生模型不仅学习了如何正确分类,还从教师模型的“软标签”中学习到了类别之间的微妙关系,从而在更小的体量下实现了更好的性能。

http://www.xdnf.cn/news/19235.html

相关文章:

  • 监听nacos配置中心数据的变化
  • vector的学习和模拟
  • 桌面GIS软件添加设置牵引文字标注
  • Fortran二维数组去重(unique)算法实战
  • 电子健康记录风险评分与多基因风险评分的互补性与跨系统推广性研究
  • 福彩双色球第2025100期篮球号码分析
  • GESP5级2024年03月真题解析
  • Coze源码分析-API授权-获取令牌列表-后端源码
  • UNet改进(36):融合FSATFusion的医学图像分割
  • TensorFlow 面试题及详细答案 120道(71-80)-- 性能优化与调试
  • Next.js 快速上手指南
  • 数值分析——算法的稳定性
  • 【ACP】2025-最新-疑难题解析- 练习二汇总
  • 文档转换总出错?PDF工具免费功能实测
  • Docker 部署深度网络模型(Flask框架思路)
  • Intellij IDEA社区版(下载安装)
  • 项目管理方法全流程解析
  • HarmonyOS 持久化存储:PersistentStorage 实战指南
  • 详解推测性采样加速推理的算法逻辑
  • nginx配置websock请求,wss
  • java中的VO、DAO、BO、PO、DO、DTO
  • 【重学 MySQL】九十三、MySQL的字符集的修改与底层原理详解
  • 项目管理和产品管理的区别
  • 【gflags】安装与使用
  • 2025 批量下载雪球和东方财富帖子和文章导出excel和pdf
  • 一体化步进伺服电机在视觉检测设备中的应用案例
  • 弱内存模型和强内存模型架构(Weak/Strong Memory Model)
  • vue3多个el-checkbox勾选框设置必选一个
  • 一款支持动态定义路径的JAVA内存马维权工具Agenst
  • 科普文章:广告技术平台的盈利模式全景