当前位置: 首页 > java >正文

Python 大模型知识蒸馏详解,知识蒸馏大模型,大模型蒸馏代码实战,LLMs knowledge distill LLM

一、原理

将待压缩的模型作为教师模型,将体积更小的模型作为学生模型,让学生模型在教师模型的监督下进行优化,将学生模型学习到教师模型的概率分布,通过kl散度进行控制。

二、方法

对于大模型的知识蒸馏,主要分为两种:

其一、黑盒知识蒸馏。

使用大模型生成数据,通过这些数据去微调更小的模型,来达到蒸馏的目的。缺点是蒸馏效率低,优点是实现简单。

其二、白盒知识蒸馏。

获取学生模型和教师模型的输出概率分布(或者中间隐藏层的概率分布),通过kl散度将学生模型的概率分布向教师模型对齐。 下面主要介绍和测试白盒知识蒸馏: 白盒知识蒸馏主要在于模型分布的对齐,模型分布对齐主要依赖kl散度,对于kl散度的使用又有如下几种方式:

(一)、前向kl散度。

也就是我们经常说的kl散度。


p为教师模型的概率分布,q为学生模型的概率分布,minillm论文中提到前向kl散度可能会使学生模型高估教师模型中概率比较低的位置,结合公式来看,当p增大时,为了使得kl散度小,则q也需要增大,但是当p趋于0时,无论q取任何值,kl散度都比较小,因为此时p(x)log((p(x)/q(x)))的大小主要受p(x)控制,这样起不到优化q分布的效果,可能会使q分布高估p分布中概率低的位置。 下图展示了前向kl散度的拟合情况,前向kl散度是一种均值搜索,更倾向于拟合多峰 

(二)、反向kl散度。

为了缓解前向kl散度的缺点,提出了反向kl散度。
 

rkl_formula


p为教师模型的概率分布,q为学生模型的概率分布,当p趋于零时,为了使kl散度小,q也需趋于0。 minillm论文中说对于大模型的知识蒸馏,反向kl散度优于前向kl散度,但是也有其他论文说反向kl散度不一定比前向kl散度更优,实际选择中,可能要基于实验驱动。 反向kl散度是一种模式搜索,更倾向于拟合单个峰 

 

(三)、偏向前kl散度。

对学生模型的分布和教师模型的分布进行加权作为学生模型的分布。

(四)、偏向反kl散度。

对学生模型的分布和教师模型的分布进行加权作为教师模型的分布。

三、测试

qwen2.5-3b作为教师模型,qwen2.5-0.5b作为学生模型
流程如下:
1、将qwen2.5-3b模型在指定数据集上微调(训练数据5000条,测试数据1000条,测试准确度为81.1%)
2、探索如下三种方案下的蒸馏效果(均使用前向kl散度):
2.1 不微调学生模型+kl散度损失
蒸馏1个epoch,准确度70.5%
蒸馏2个epoch,准确度73%
2.2 微调学生模型(模型准确度80.3%)+kl散度损失
蒸馏2个epoch,准确度61.9%
2.3 不微调学生模型+kl散度损失和交叉熵损失加权
蒸馏2个epoch,70.5%
3、上述实验中只使用kl散度的效果最好,如下实验中使用kl散度的变种进行测试,经过测试,效果都不如前向kl散度效果好。
3.1 反向kl散度
准确率只有54%
3.2 偏向前向kl散度
损失下降异常,效果很差,不断重复输出。
由于资源和时间的限制,所有测试均保持相同的超参数,未针对不同损失设置不同超参数。

四、代码实战

train.py

from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator
from peft import LoraConfig, get_peft_model, TaskType
from peft import PeftModel
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Trainer, TrainingArguments
from dataset import SFTDataset
from utils import compute_fkl, compute_rkl, compute_skewed_fkl, compute_skewed_rklclass KGTrainer(Trainer):def __init__(self,model = None,teacher_model = None,if_use_entropy = False,args = None,data_collator = None, train_dataset = None,eval_dataset = None,tokenizer = None,model_init = None, compute_metrics = None, callbacks = None,optimizers = (None, None), preprocess_logits_for_metrics = None,):super().__init__(model,args,data_collator,train_dataset,eval_dataset,tokenizer,model_init,compute_metrics,callbacks,optimizers,preprocess_logits_for_metrics,)self.teacher_model = teacher_modelself.if_use_entropy = if_use_entropydef compute_loss(self, model, inputs, return_outputs=False):outputs = model(**inputs)with torch.no_grad():teacher_outputs = self.teacher_model(**inputs)loss = outputs.losslogits = outputs.logitsteacher_logits = teacher_outputs.logits# 如果教师模型和学生模型输出形状不匹配,对学生模型进行padding或对教师模型进行截断if logits.shape[-1] != teacher_logits.shape[-1]:# gap = teacher_logits.shape[-1] - logits.shape[-1]# if gap > 0:#     pad_logits = torch.zeros((logits.shape[0], logits.shape[1], gap)).to(logits.device)#     logits = torch.cat([logits, pad_logits], dim=-1)teacher_logits = teacher_logits[:, :, :logits.shape[-1]]labels = inputs['labels']kl = compute_fkl(logits, teacher_logits, labels, padding_id=-100, temp=2.0)if self.if_use_entropy:loss_total = 0.5 * kl + 0.5 * losselse:loss_total = klreturn (loss_total, outputs) if return_outputs else loss_totalif __name__ == '__main__':# 学生模型model = AutoModelForCausalLM.from_pretrained("Qwen2.5-0.5B-Instruct")lora_config = LoraConfig(r=8,  lora_alpha=256,  target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],lora_dropout=0.1, task_type=TaskType.CAUSAL_LM)# 使用lora方法训练model = get_peft_model(model, lora_config)model.cuda()print(model.print_trainable_parameters())tokenizer = AutoTokenizer.from_pretrained("Qwen2.5-0.5B-Instruct")# 教师模型,在给定数据上通过lora微调teacher_model = AutoModelForCausalLM.from_pretrained("Qwen2.5-7B-Instruct")# 是否加载lora模型lora_path = 'qwen2.5_7b/lora/sft'teacher_model = PeftModel.from_pretrained(teacher_model, lora_path)teacher_model.cuda()teacher_model.eval()args = TrainingArguments(output_dir='./results', num_train_epochs=10, do_train=True, per_device_train_batch_size=2,gradient_accumulation_steps=16,logging_steps=10,report_to='tensorboard',save_strategy='epoch',save_total_limit=10,bf16=True,learning_rate=0.0005,lr_scheduler_type='cosine',dataloader_num_workers=8,dataloader_pin_memory=True)data_collator = DefaultDataCollator()dataset = SFTDataset('data.json', tokenizer=tokenizer, max_seq_len=512)trainer = KGTrainer(model=model,teacher_model=teacher_model, if_use_entropy = True,args=args, train_dataset=dataset, tokenizer=tokenizer, data_collator=data_collator)# 如果是初次训练resume_from_checkpoint为false,接着checkpoint继续训练,为Truetrainer.train(resume_from_checkpoint=False)trainer.save_model('./saves')trainer.save_state()

dataset.py 

import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
import os
import pandas as pdfrom torch.utils.data import IterableDataset, Dataset
import json
import numpy as np
from transformers import  PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import PretrainedConfig
from transformers import Trainer, TrainingArguments, AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator, DataCollatorForTokenClassification, AutoConfigclass SFTDataset(Dataset):def __init__(self, data_path, tokenizer, max_seq_len):super().__init__()self.data_path = data_pathself.tokenizer = tokenizerself.max_seq_len = max_seq_lenself.padding_id = tokenizer.pad_token_idwith open(self.data_path, 'r', encoding='utf-8') as f:self.data = json.load(f)def __len__(self):return len(self.data)    def __getitem__(self, index):line = self.data[index]instruction_text = line['instruction']input_text = line['input']output_text = line['output']query = instruction_text + input_textanswer = output_text + self.tokenizer.eos_tokenmessages = []messages.append({'role': 'user', 'content': query})   prompt = self.tokenizer.apply_chat_template(messages, tokenize=False) prompt_input_ids = self.tokenizer.encode(prompt)answer_input_ids = self.tokenizer.encode(answer)input_ids = prompt_input_ids + answer_input_idslabels = [-100] * len(prompt_input_ids) + answer_input_idsattention_mask = [1] * len(input_ids)text_len = len(input_ids)if text_len > self.max_seq_len:input_ids = input_ids[:self.max_seq_len]labels = labels[:self.max_seq_len]attention_mask = attention_mask[:self.max_seq_len]else:input_ids = input_ids + [self.tokenizer.pad_token_id] * (self.max_seq_len - text_len)labels = labels + [-100] * (self.max_seq_len - text_len)attention_mask = attention_mask + [0] * (self.max_seq_len - text_len)# input_ids = input_ids[:-1]# labels = labels[1:]return {'input_ids': torch.tensor(input_ids), 'attention_mask':torch.tensor(attention_mask), 'labels': torch.tensor(labels)}

utils.py

import torch# 计算前向kl散度
def compute_fkl(logits, teacher_logits, target, padding_id,reduction="sum",temp = 1.0, ):logits = logits / tempteacher_logits = teacher_logits / templog_probs = torch.log_softmax(logits, -1, dtype=torch.float32)teacher_probs = torch.softmax(teacher_logits, -1, dtype=torch.float32)teacher_log_probs = torch.log_softmax(teacher_logits, -1, dtype=torch.float32)kl = (teacher_probs * (teacher_log_probs - log_probs)) kl = kl.sum(-1)if reduction == "sum":pad_mask = target.eq(padding_id)kl = kl.masked_fill_(pad_mask, 0.0)kl = kl.sum()return kl
# 计算反向kl散度
def compute_rkl(logits, teacher_logits, target, padding_id,reduction="sum", temp = 1.0):logits = logits / tempteacher_logits = teacher_logits / tempprobs = torch.softmax(logits, -1, dtype=torch.float32)log_probs = torch.log_softmax(logits, -1, dtype=torch.float32)teacher_log_probs = torch.log_softmax(teacher_logits, -1, dtype=torch.float32)kl = (probs * (log_probs - teacher_log_probs))kl = kl.sum(-1)if reduction == "sum":pad_mask = target.eq(padding_id)kl = kl.masked_fill_(pad_mask, 0.0)kl = kl.sum()return kl# 计算偏向前kl散度
def compute_skewed_fkl(logits, teacher_logits, target, padding_id, reduction="sum", temp = 1.0,skew_lambda = 0.1):logits = logits / tempteacher_logits = teacher_logits / tempprobs = torch.softmax(logits, -1, dtype=torch.float32)teacher_probs = torch.softmax(teacher_logits, -1, dtype=torch.float32)mixed_probs = skew_lambda * teacher_probs + (1 - skew_lambda) * probsmixed_log_probs = torch.log(mixed_probs)teacher_log_probs = torch.log_softmax(teacher_logits, -1, dtype=torch.float32)kl = (teacher_probs * (teacher_log_probs - mixed_log_probs))kl = kl.sum(-1)if reduction == "sum":pad_mask = target.eq(padding_id)kl = kl.masked_fill_(pad_mask, 0.0)kl = kl.sum()return kl
# 计算偏向反kl散度    
def compute_skewed_rkl(logits, teacher_logits, target,padding_id,reduction="sum", temp = 1.0,skew_lambda = 0.1
):logits = logits / tempteacher_logits = teacher_logits / tempprobs = torch.softmax(logits, -1, dtype=torch.float32)teacher_probs = torch.softmax(teacher_logits, -1, dtype=torch.float32)mixed_probs = (1 - skew_lambda) * teacher_probs + skew_lambda * probsmixed_log_probs = torch.log(mixed_probs)log_probs = torch.log_softmax(logits, -1, dtype=torch.float32)kl = (probs * (log_probs - mixed_log_probs))kl = kl.sum(-1)if reduction == "sum":pad_mask = target.eq(padding_id)kl = kl.masked_fill_(pad_mask, 0.0)kl = kl.sum()return kl

五、data格式

instruction_text = line['instruction']
input_text = line['input']
output_text = line['output'] 

{

'instruction':'很厉害的专家',

'input':'写一首诗',

'output':'巴拉巴拉小魔仙',

http://www.xdnf.cn/news/8438.html

相关文章:

  • stm32上拉电阻,1K,4.7K,5.6K,10K怎么选?
  • 职业规划:动态迭代的系统化路径
  • javaScirpt学习第五章(函数)-第一部分
  • 【Web前端】JavaScript入门与基础(一)
  • WebStorm 高效快捷方式全解析
  • 11.5 Python+LangGraph智能代理开发:节点设计与业务流实战全解析
  • 【通用智能体】smolagents/open_deep_research:面向开放式研究的智能体开发框架深度解析
  • Vue3 对象转换
  • 七:操作系统文件系统之目录结构
  • 【Elasticsearch】创建别名的几种方式
  • 算法打卡第五天
  • 三、如何优化opengl在gpu上的渲染性能
  • 「EMD/EEMD/VMD 信号分解方法 ——ECG信号处理-第十四课」2025年5月23日
  • 每日Prompt:虚拟世界游
  • Linux性能监控:工具与最佳实践
  • Vue.js教学第十二章:Vue Router实战指南(二)
  • C++ 日志系统实战第六步:性能测试
  • Day 29 训练
  • 永磁同步电机控制算法-滑模反馈线性化控制器
  • 红队攻防实践:15大漏洞原理与复现全解析
  • 【agent】简历信息提取智能体
  • AGV(自动导引车)通信协议及通信链路性能需求分析
  • 力扣HOT100之图论:994. 腐烂的橘子
  • 二、详细解释OpenGL图形管线中顶点处理阶段的工作原理
  • day57—快速(选择/排序)—数组中的第 K 个最大元素(LeetCode-215)
  • 国家网络身份认证公共服务管理办法
  • nginx配置跨域请求,后台不用配置啦,完美
  • vue 水印组件
  • 【Dv3Admin】插件 dv3admin_chatgpt 优化支持多种启动方式实现SSE效果
  • QT之巧用对象充当信号接收者