【系列07】端侧AI:构建与部署高效的本地化AI模型 第6章:知识蒸馏(Knowledge Distillation
第6章:知识蒸馏(Knowledge Distillation)
在构建端侧AI模型时,我们常常面临一个两难的局面:一方面需要大模型的强大性能,另一方面又必须满足端侧设备对模型体积和计算效率的要求。知识蒸馏是一种优雅的解决方案,它允许我们用一个大型的、性能优越的“教师模型”来指导一个小型、高效的“学生模型”的学习,从而让学生模型在保持轻量化的同时,获得接近教师模型的性能。
什么是知识蒸馏?
知识蒸馏的核心思想是转移知识。它不是简单地让学生模型去学习标注好的“硬标签”(hard labels),而是让它去学习教师模型的“软标签”(soft labels)。
- 硬标签:指数据集中明确的类别标签,例如一张图片是“猫”或“狗”。学生模型的目标是尽可能地预测出正确的硬标签。
- 软标签:指教师模型对每个类别的预测概率分布。例如,教师模型不仅会预测图片是“猫”,还会给出“狗”的概率是0.05,“老虎”的概率是0.02。这个概率分布包含了比单一硬标签更丰富的知识,因为它体现了不同类别之间的相似性和关系。
知识蒸馏的原理就是通过损失函数让学生模型的预测概率分布尽可能地接近教师模型的预测概率分布。
如何用大模型(教师模型)指导小模型(学生模型)的学习
知识蒸馏的训练过程可以概括为以下步骤:
- 选择教师模型:首先,你需要一个已经训练好的、性能强大的模型,作为你的教师。这个模型通常非常大,不适合直接部署。
- 选择学生模型:然后,你需要一个更小、更简单的模型,它将作为你的学生。这个模型需要有足够的容量来学习教师的知识。
- 构建训练流程:在训练阶段,你需要同时运行教师模型和学生模型。
- 将同一批数据输入给教师模型,得到其预测的软标签(概率分布)。
- 将同一批数据输入给学生模型,得到其预测的概率分布。
- 计算损失函数:知识蒸馏的损失函数通常由两部分组成:
- 蒸馏损失(Distillation Loss):用于衡量学生模型的概率分布与教师模型的软标签之间的差异。通常使用KL散度(Kullback-Leibler divergence)来计算。
- 学生损失(Student Loss):用于衡量学生模型与真实硬标签之间的差异。通常使用交叉熵损失。
- 联合优化:通过联合优化这两个损失函数,学生模型不仅学习了硬标签,还从教师模型那里“继承”了更深层次的模式和知识。
实践:构建一个学生网络,并用一个预训练好的教师模型进行蒸馏
下面是一个使用PyTorch进行知识蒸馏的简化代码示例。我们将使用一个预训练的ResNet18作为教师,并构建一个更简单的网络作为学生。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import models# 1. 定义教师模型 (使用预训练的ResNet18)
teacher_model = models.resnet18(pretrained=True)
teacher_model.eval() # 确保教师模型处于评估模式# 2. 定义学生模型 (一个简单的全连接网络)
class StudentNet(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(1000, 500)self.relu = nn.ReLU()self.fc2 = nn.Linear(500, 100)def forward(self, x):x = self.relu(self.fc1(x))x = self.fc2(x)return xstudent_model = StudentNet()# 3. 定义损失函数
# 这里我们用两个损失函数,一个用于蒸馏,一个用于学生自己的学习
distillation_loss = nn.KLDivLoss(reduction="batchmean")
student_loss = nn.CrossEntropyLoss()# 4. 定义优化器
optimizer = optim.Adam(student_model.parameters(), lr=0.001)# 5. 训练循环 (简化版)
# 假设我们有一个dataloder
# for inputs, labels in dataloader:
# # 将数据输入教师模型
# with torch.no_grad():
# teacher_outputs = teacher_model(inputs)# # 将数据输入学生模型
# student_outputs = student_model(inputs)# # 计算损失
# # 温度T是一个超参数,用于平滑概率分布
# T = 2.0
# loss_distillation = distillation_loss(
# F.log_softmax(student_outputs / T, dim=1),
# F.softmax(teacher_outputs / T, dim=1)
# )# # 硬标签损失
# loss_student = student_loss(student_outputs, labels)# # 联合损失,通常会给两个损失分配权重
# alpha = 0.5
# total_loss = alpha * loss_distillation + (1 - alpha) * loss_student# # 反向传播和优化
# optimizer.zero_grad()
# total_loss.backward()
# optimizer.step()
通过这样的训练流程,学生模型不仅学习了如何正确分类,还从教师模型的“软标签”中学习到了类别之间的微妙关系,从而在更小的体量下实现了更好的性能。