当前位置: 首页 > news >正文

大模型蒸馏(distillation)---从DeepseekR1-1.5B到Qwen-2.5-1.5B蒸馏

 

目录

1.1 蒸馏目标

2 环境准备

2.1依赖库安装

2.2 硬件要求

2.3 模型与数据集下载

2.3.1 教师模型下载

2.3.2 学生模型下载

 2.3.3 数据集准备或下载

 3.过程日志

 4. 模型加载与配置

4.1 加载教师模型

4.2 加载学生模型

4.3 数据预处理函数  

 4.4 数据收集器

4.5 定义训练参数

4.6 定义蒸馏配置

4.7 定义训练配置

4.8 创建蒸馏器 

4.9 开始蒸馏 

 5. 完整代码

6.结合上述内容和TextBrewer,自己重新整理了一遍代码,仅供参考:

1.1 蒸馏目标

将 DeepSeek 的推理能力迁移到 Qwen-2.5;

确保学生模型与 Qwen-2.5 的原始功能(如对话、多语言支持)兼容。

2 环境准备

2.1依赖库安装

pip install torch torchvision transformers datasets2.2
pip install accelerate # 加速分布式训练
pip install evaluate # 评估指标

2.2 硬件要求

GPU:建议使用单张或多张 NVIDIA GPU(如 V100、A100)建议至少 24GB。

CUDA:安装与 PyTorch 兼容的 CUDA 版本。

2.3 模型与数据集下载

2.3.1 教师模型下载

Qwen-2.5-1.5B从huggingface 下载,离线下载方式(从https://hf-mirror.com离线下载):

$env:HF_ENDPOINT = "https://hf-mirror.com"huggingface-cli download Qwen/Qwen2.5-1.5B --local-dir ./models/qwen2.5-1.5B --local-dir-use-symlinks False

2.3.2 学生模型下载

Qwen-2.5-1.5B

$env:HF_ENDPOINT = "https://hf-mirror.com"huggingface-cli download Qwen/Qwen2.5-1.5B --local-dir ./models/qwen2.5-1.5B --local-dir-use-symlinks False

 2.3.3 数据集准备或下载

建议使用大规模文本数据集(如 wikitex、Wikipedia、BooksCorpus、OpenWebText 等)。离线下载地址(从https://www.kaggle.com/datasets/jayanthbontha/wikitext下载)

 3.过程日志

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)# 获取当前脚本文件的绝对路径
current_script_path = os.path.abspath(__file__)
logger.info(f"Current script path: {current_script_path}")# 获取当前脚本文件所在的目录
current_script_dir = os.path.dirname(current_script_path)
logger.info(f"Current script directory: {current_script_dir}")

 4. 模型加载与配置

4.1 加载教师模型

# 加载教师模型(DeepSeek-R1:1.5B)
teacher_model_name = os.path.join(current_script_dir, "../models/DeepSeek-R1-Distill-Qwen-1.5B")
logger.info(f"Loading teacher model: {teacher_model_name}")
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name,local_files_only=True
)
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model_name,local_files_only=True
)

4.2 加载学生模型

# 加载学生模型(Qwen)
student_model_name = os.path.join(current_script_dir, "../models/qwen2.5-1.5B")  # 确保模型名称正确
logger.info(f"Loading student model: {student_model_name}")
student_tokenizer = AutoTokenizer.from_pretrained(student_model_name,local_files_only=True
)
student_model = AutoModelForCausalLM.from_pretrained(student_model_name,local_files_only=True
)

4.3 数据预处理函数  

dataset.map() 是 Hugging Face datasets 库中用于对数据集进行批量预处理的核心方法。当 batched=True 时,它会将数据集分批(batch)传递给 preprocess_function,而不是逐个样本处理。这种批量处理方式效率更高,尤其适合大规模数据集。

# 数据预处理
logger.info(f"Preprocess_function")
def preprocess_function(examples):return teacher_tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)logger.info("Preprocessing train dataset")
train_dataset = train_dataset.map(preprocess_function, batched=True)
logger.info("Preprocessing eval dataset")
eval_dataset = eval_dataset.map(preprocess_function, batched=True)

 4.4 数据收集器

DataCollatorForLanguageModeling 是 Hugging Face transformers 库中的一个数据整理类(Data Collator),用于在训练语言模型(如 BERT、GPT 等)时动态生成训练样本。它可以根据任务需求(如掩码语言模型(MLM)或因果语言模型(CLM))对输入数据进行预处理。

# 数据收集器
logger.info("DataCollatorForLanguageModeling")
data_collator = DataCollatorForLanguageModeling(tokenizer=teacher_tokenizer, mlm=False)

mlm(关键参数):作用:控制是否启用**掩码语言模型(MLM)**模式。

mlm=True:随机掩码输入中的部分 token(如 BERT 训练方式),生成 [MASK] 标记。

mlm=False:禁用掩码,适用于因果语言模型(CLM)(如 GPT 训练方式),输入和标签为原始 token 序列。

4.5 定义训练参数

# 定义训练参数
logger.info("Creating trainer")
training_args = TrainingArguments(output_dir="./results",            # 训练结果保存路径eval_strategy="epoch",             # 每个 epoch 结束时评估learning_rate=5e-5,                # 学习率(默认 5e-5 是常见选择)per_device_train_batch_size=2,     # 每个设备的训练 batch size(GPU 单卡)per_device_eval_batch_size=2,      # 每个设备的评估 batch sizenum_train_epochs=3,                # 训练轮次(3 轮可能较短,需根据任务调整)weight_decay=0.01,                 # 权重衰减(L2 正则化)logging_dir="./logs",              # 日志保存路径logging_steps=100,                 # 每 100 步记录一次日志fp16=False,                        # 是否启用混合精度训练(建议开启)gradient_accumulation_steps=4,     # 梯度累积步数(等效 batch_size=8)report_to="tensorboard",           # 使用 TensorBoard 记录训练过程# tensorboard_dir="./tensorboard"  # 可选:指定 TensorBoard 日志目录
)

4.6 定义蒸馏配置

# 定义蒸馏配置  weight:添加权重,"loss": "mse"
logger.info("Creating distillation config")
distill_config = DistillationConfig(temperature=2.0,  # 温度参数,控制软标签的平滑程度hard_label_weight=0.5,  # 真实标签损失权重kd_loss_type="ce",      # 知识蒸馏损失类型(交叉熵)intermediate_matches=[  # 中间层匹配配置{"layer_T": 6,    # 教师模型的第6层"layer_S": 6,    # 学生模型的第6层"feature": "hidden",  # 匹配隐藏层特征"weight": 1.0,   # 中间层损失权重"loss": "mse"    # 使用均方误差损失}])

4.7 定义训练配置

# 定义训练配置
logger.info("Creating training config")
train_config = TrainingConfig(device="cuda" if torch.cuda.is_available() else "cpu",  # 设备选择log_dir="./logs",                                     # 日志目录output_dir="./outputs"                                # 模型输出目录# save_best_model=True,  # 是否保存最佳模型(注释状态)# save_last_model=True,  # 是否保存最后模型(注释状态)# save_model_every_epoch=True,  # 是否每轮保存模型(注释状态)# tensorboard_dir="./tensorboard"  # TensorBoard 日志目录(注释状态))

4.8 创建蒸馏器 

# 创建蒸馏器
logger.info("Creating distiller")
distiller = GeneralDistiller(train_config=train_config,        # 训练配置(包含设备、路径等)distill_config=distill_config,    # 蒸馏配置(温度、损失权重等)model_T=teacher_model,            # 教师模型model_S=student_model,            # 学生模型adaptor_T=None,                   # 教师模型适配器(未配置)adaptor_S=None                    # 学生模型适配器(未配置)
)

4.9 开始蒸馏 

# 开始蒸馏
with distiller:  # 使用蒸馏器上下文管理器,确保资源正确初始化和释放logger.info("Starting training")  # 记录训练开始日志# 初始化 Trainer,集成模型蒸馏配置trainer = Trainer(model=student_model,  # 学生模型(需要训练的小模型)args=training_args,   # 训练参数(如学习率、批次大小、设备等)train_dataset=train_dataset,  # 训练数据集(包含输入和标签)eval_dataset=eval_dataset,    # 验证数据集(用于评估模型性能)data_collator=data_collator,  # 数据批量处理函数(将单条数据组合成批次)# processing_class=teacher_tokenizer  # 注意:此处可能存在问题(见下方说明)# 正确做法:适配器或数据处理逻辑应在蒸馏配置中处理)# 开始模型训练trainer.train()  # 启动训练循环,包含前向传播、损失计算、反向传播等logger.info("Training finished")  # 记录训练结束日志

 5. 完整代码

import osimport torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, Trainer, \TrainingArguments
from textbrewer import GeneralDistiller, TrainingConfig, DistillationConfig
from datasets import load_dataset
import logging# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)# 获取当前脚本文件的绝对路径
current_script_path = os.path.abspath(__file__)
logger.info(f"Current script path: {current_script_path}")# 获取当前脚本文件所在的目录
current_script_dir = os.path.dirname(current_script_path)
logger.info(f"Current script directory: {current_script_dir}")# 加载教师模型(DeepSeek-R1:1.5B)
teacher_model_name = os.path.join(current_script_dir, "../models/DeepSeek-R1-Distill-Qwen-1.5B")
logger.info(f"Loading teacher model: {teacher_model_name}")
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name,local_files_only=True
)
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model_name,local_files_only=True
)# 加载学生模型(Qwen)
student_model_name = os.path.join(current_script_dir, "../models/qwen2.5-1.5B")  # 确保模型名称正确
logger.info(f"Loading student model: {student_model_name}")
student_tokenizer = AutoTokenizer.from_pretrained(student_model_name,local_files_only=True
)
student_model = AutoModelForCausalLM.from_pretrained(student_model_name,local_files_only=True
)# 准备数据集
datasets_name = os.path.join(current_script_dir, "../models/Dataset/wikitext-2-raw/")  # 确保模型名称正确
data_files = {"train": datasets_name+"wiki.train.raw","test": datasets_name+"wiki.test.raw"
}
logger.info(f"Loading dataset from local files: {data_files}")
dataset = load_dataset("text", data_files=data_files)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]# 数据预处理
logger.info(f"Preprocess_function")
def preprocess_function(examples):return teacher_tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)logger.info("Preprocessing train dataset")
train_dataset = train_dataset.map(preprocess_function, batched=True)
logger.info("Preprocessing eval dataset")
eval_dataset = eval_dataset.map(preprocess_function, batched=True)# 数据收集器
logger.info("DataCollatorForLanguageModeling")
data_collator = DataCollatorForLanguageModeling(tokenizer=teacher_tokenizer, mlm=False)# 定义训练参数
logger.info("Creating trainer")
training_args = TrainingArguments(output_dir="./results",            # 训练结果保存路径eval_strategy="epoch",             # 每个 epoch 结束时评估learning_rate=5e-5,                # 学习率(默认 5e-5 是常见选择)per_device_train_batch_size=2,     # 每个设备的训练 batch size(GPU 单卡)per_device_eval_batch_size=2,      # 每个设备的评估 batch sizenum_train_epochs=3,                # 训练轮次(3 轮可能较短,需根据任务调整)weight_decay=0.01,                 # 权重衰减(L2 正则化)logging_dir="./logs",              # 日志保存路径logging_steps=100,                 # 每 100 步记录一次日志fp16=False,                        # 是否启用混合精度训练(建议开启)gradient_accumulation_steps=4,     # 梯度累积步数(等效 batch_size=8)report_to="tensorboard",           # 使用 TensorBoard 记录训练过程# tensorboard_dir="./tensorboard"  # 可选:指定 TensorBoard 日志目录
)# 定义蒸馏配置  weight:添加权重,"loss": "mse"
logger.info("Creating distillation config")
distill_config = DistillationConfig(temperature=2.0,  # 温度参数,控制软标签的平滑程度hard_label_weight=0.5,  # 真实标签损失权重kd_loss_type="ce",      # 知识蒸馏损失类型(交叉熵)intermediate_matches=[  # 中间层匹配配置{"layer_T": 6,    # 教师模型的第6层"layer_S": 6,    # 学生模型的第6层"feature": "hidden",  # 匹配隐藏层特征"weight": 1.0,   # 中间层损失权重"loss": "mse"    # 使用均方误差损失}]
)# 定义训练配置
logger.info("Creating training config")
train_config = TrainingConfig(device="cuda" if torch.cuda.is_available() else "cpu",  # 设备选择log_dir="./logs",                                     # 日志目录output_dir="./outputs"                                # 模型输出目录# save_best_model=True,  # 是否保存最佳模型(注释状态)# save_last_model=True,  # 是否保存最后模型(注释状态)# save_model_every_epoch=True,  # 是否每轮保存模型(注释状态)# tensorboard_dir="./tensorboard"  # TensorBoard 日志目录(注释状态)
)# 创建蒸馏器
logger.info("Creating distiller")
distiller = GeneralDistiller(train_config=train_config,        # 训练配置(包含设备、路径等)distill_config=distill_config,    # 蒸馏配置(温度、损失权重等)model_T=teacher_model,            # 教师模型model_S=student_model,            # 学生模型adaptor_T=None,                   # 教师模型适配器(未配置)adaptor_S=None                    # 学生模型适配器(未配置)
)# 开始蒸馏
with distiller:  # 使用蒸馏器上下文管理器,确保资源正确初始化和释放logger.info("Starting training")  # 记录训练开始日志# 初始化 Trainer,集成模型蒸馏配置trainer = Trainer(model=student_model,  # 学生模型(需要训练的小模型)args=training_args,  # 训练参数(如学习率、批次大小、设备等)train_dataset=train_dataset,  # 训练数据集(包含输入和标签)eval_dataset=eval_dataset,  # 验证数据集(用于评估模型性能)data_collator=data_collator,  # 数据批量处理函数(将单条数据组合成批次)# processing_class=teacher_tokenizer  # 注意:此处可能存在问题(见下方说明)# 正确做法:适配器或数据处理逻辑应在蒸馏配置中处理)# 开始模型训练trainer.train()  # 启动训练循环,包含前向传播、损失计算、反向传播等trainer.save_model()logger.info("Training finished")  # 记录训练结束日志
复制代码

参考:

模型蒸馏(Distillation)案例--从DeepSeek-R1-1.5B 到 Qwen-2.5-1.5B 的模型蒸馏 - InProsperity - 博客园

模型蒸馏(Distillation)案例--从DeepSeek-R1-1.5B 到 Qwen-2.5-1.5B 的模型蒸馏-CSDN博客

6.结合上述内容和TextBrewer,自己重新整理了一遍代码,仅供参考:

import os
import torch
import logging
from transformers import (AutoModelForCausalLM,AutoTokenizer,DataCollatorForLanguageModeling,get_linear_schedule_with_warmup
)
from textbrewer import GeneralDistiller, TrainingConfig, DistillationConfig
from datasets import load_dataset
from torch.optim import AdamW# 配置日志
logging.basicConfig(level=logging.INFO,format='%(asctime)s - %(levelname)s - %(message)s',handlers=[logging.FileHandler("distillation.log"),logging.StreamHandler()]
)
logger = logging.getLogger(__name__)# 设备设置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")# ======================
# 1. 加载模型和Tokenizer
# ======================
def load_models_and_tokenizers():"""加载教师和学生模型"""# 教师模型 (DeepSeek-R1 1.5B)teacher_model_name = "deepseek-ai/deepseek-r1-1.5b"logger.info(f"Loading teacher model: {teacher_model_name}")teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model_name,torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32).to(device)# 学生模型 (Qwen 1.5B)student_model_name = "Qwen/Qwen1.5-1.8B"logger.info(f"Loading student model: {student_model_name}")student_tokenizer = AutoTokenizer.from_pretrained(student_model_name)student_model = AutoModelForCausalLM.from_pretrained(student_model_name,torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32).to(device)# 对齐tokenizer(关键步骤!)if teacher_tokenizer.vocab != student_tokenizer.vocab:logger.warning("Tokenizers not aligned, adding special tokens...")student_tokenizer.add_special_tokens({'pad_token': '[PAD]'})student_model.resize_token_embeddings(len(student_tokenizer))return teacher_model, student_model, teacher_tokenizer, student_tokenizer# ======================
# 2. 数据准备
# ======================
def prepare_data(student_tokenizer):"""加载并预处理数据"""# 加载数据集(示例使用wikitext)dataset = load_dataset("wikitext", "wikitext-2-raw-v1")# 预处理函数def preprocess_function(examples):return student_tokenizer(examples["text"],truncation=True,padding="max_length",max_length=512,return_tensors="pt")# 处理数据集train_dataset = dataset["train"].map(preprocess_function,batched=True,remove_columns=["text"])eval_dataset = dataset["validation"].map(preprocess_function,batched=True,remove_columns=["text"])# 数据收集器data_collator = DataCollatorForLanguageModeling(tokenizer=student_tokenizer,mlm=False)return train_dataset, eval_dataset, data_collator# ======================
# 3. 蒸馏配置
# ======================
def get_distillation_config():"""配置蒸馏参数"""return DistillationConfig(temperature=2.0,  # 初始温度temperature_scheduler=lambda x: max(0.5, 2.0 - 0.1 * x),  # 温度衰减hard_label_weight=0.3,  # 真实标签权重kd_loss_weight=0.7,  # 蒸馏损失权重kd_loss_type="ce",  # 交叉熵损失intermediate_matches=[{"layer_T": [6, 12, 18],  # 教师模型层"layer_S": [3, 6, 9],  # 学生模型层"feature": "hidden",  # 隐藏状态"loss": "cosine",  # 余弦相似度损失"weight": 0.5,"proj": ["linear", 1536, 1024]  # 维度投影},{"layer_T": [9, 15],"layer_S": [4, 7],"feature": "attention",  # 注意力矩阵"loss": "mse","weight": 0.3}])# ======================
# 4. 训练配置
# ======================
def get_training_config():"""配置训练参数"""return TrainingConfig(output_dir="./distill_output",device=device,fp16=torch.cuda.is_available(),gradient_accumulation_steps=4,ckpt_frequency=500,  # 每500步保存检查点log_steps=100,max_grad_norm=1.0,  # 梯度裁剪save_optimizer=False  # 为节省空间不保存优化器)# ======================
# 5. 优化器设置
# ======================
def get_optimizer(model):"""配置优化器和学习率调度"""optimizer = AdamW(model.parameters(),lr=5e-5,weight_decay=0.01)scheduler = get_linear_schedule_with_warmup(optimizer,num_warmup_steps=500,num_training_steps=3000)return optimizer, scheduler# ======================
# 主函数
# ======================
def main():# 1. 加载模型和数据teacher_model, student_model, teacher_tokenizer, student_tokenizer = load_models_and_tokenizers()train_dataset, eval_dataset, data_collator = prepare_data(student_tokenizer)# 2. 配置蒸馏distill_config = get_distillation_config()train_config = get_training_config()# 3. 初始化蒸馏器distiller = GeneralDistiller(train_config=train_config,distill_config=distill_config,model_T=teacher_model,model_S=student_model,adaptor_T=None,  # 使用默认适配器adaptor_S=None)# 4. 准备优化器optimizer, scheduler = get_optimizer(student_model)# 5. 开始蒸馏logger.info("Starting distillation...")with distiller:distiller.train(optimizer=optimizer,scheduler=scheduler,train_dataset=train_dataset,eval_dataset=eval_dataset,batch_size=2,num_epochs=3,data_collator=data_collator,callback=None)# 6. 保存最终模型student_model.save_pretrained("./final_student_model")student_tokenizer.save_pretrained("./final_student_model")logger.info("Distillation completed!")if __name__ == "__main__":main()

另外,可以了解Text Generation WebUI,集成不同大模型进行推理,测试。

https://github.com/oobabooga/text-generation-webui

http://www.xdnf.cn/news/1193419.html

相关文章:

  • ARM SMMUv3控制器注册过程分析(八)
  • 二分函数 lower_bound upper_bound
  • 21-ospf多区域
  • 【Bluedroid】btif_av_sink_execute_service之服务器禁用源码流程解析
  • Apache Doris Data Agent 解决方案:开启智能运维与数据治理新纪元
  • 2025年入局苹果Vision Pro开发:从零到发布的完整路线图
  • LeetCode 刷题【15. 三数之和】
  • 如何关闭Windows自动更新?【图文详解】win10/win11关闭自动更新
  • CentOS 7 安装 MySQL 8.4.6(二进制包)指南
  • Linux——线程同步
  • CT、IT、ICT 和 DICT区别
  • 【架构】Docker简单认知构建
  • 【科研绘图系列】R语言绘制误差连线散点图
  • 秋招Day19 - 分布式 - 分布式事务
  • 生产环境使用云服务器(centOS)部署和使用MongoDB
  • Java操作Excel文档
  • opencv学习(图像金字塔)
  • 背包问题及 LIS 优化
  • 告别配置混乱!Spring Boot 中 Properties 与 YAML 的深度解析与最佳实践
  • C#编程基础:运算符与结构详解
  • 【Android】相对布局应用-登录界面
  • 2025.7.26字节掀桌子了,把coze开源了!!!
  • window下MySQL安装(三)卸载mysql
  • Fast_Lio 修改激光雷达话题
  • VLAN的划分(基于华为eNSP)
  • MySQL 8.0 OCP 1Z0-908 题目解析(37)
  • 尝试几道算法题,提升python编程思维
  • Linux内核设计与实现 - 课程大纲
  • LeetCode 1074:元素和为目标值的子矩阵数量
  • 使用Spring Boot创建Web项目