基于 LoRA的广义知识蒸馏(GKD)训练
基于 LoRA的广义知识蒸馏(GKD)训练
flyfish
通过参数高效的 LoRA(低秩适应)技术,结合广义知识蒸馏(GKD)方法,让小尺寸的学生模型(如 Qwen2-0.5B-Instruct)高效学习大尺寸教师模型(如 Qwen2-1.5B-Instruct)的知识和能力,最终在减少计算资源消耗的前提下,提升小模型的对话性能,使其接近大模型的水平。
python examples/scripts/gkd.py \--model_name_or_path Qwen/Qwen2-0.5B-Instruct \ # 学生模型:小模型,待蒸馏的模型--teacher_model_name_or_path Qwen/Qwen2-1.5B-Instruct \ # 教师模型:大模型,提供知识的模型--dataset_name trl-lib/chatbot_arena_completions \ # 训练数据集:对话竞技场数据(含高质量对话)--learning_rate 2e-4 \ # 学习率:LoRA微调通常用更大的学习率(全量训练一般2e-5)--per_device_train_batch_size 4 \ # 单设备训练批次大小--gradient_accumulation_steps 8 \ # 梯度累积步数:总批次=4*8=32(节省显存)--output_dir gkd-model \ # 模型保存路径--num_train_epochs 1 \ # 训练轮次--push_to_hub \ # 训练后推送到Hugging Face Hub--gradient_checkpointing \ # 启用梯度检查点:牺牲少量速度换显存--use_peft \ # 启用PEFT(参数高效微调)框架,这里用于LoRA--lora_r 64 \ # LoRA的秩(秩越低,参数量越少)--lora_alpha 16 # LoRA的缩放系数(控制更新幅度)
解析参数 → 加载模型 / Tokenizer → 加载数据集 → 初始化 GKD 训练器 → 执行训练 → 保存模型。
# 导入必要的库
# 加载数据集的工具
from datasets import load_dataset
# 加载分词器和生成配置的工具
from transformers import AutoTokenizer, GenerationConfig# 从trl库导入GKD相关的配置、训练器和工具
from trl import (GKDConfig, # GKD训练的核心配置类GKDTrainer, # GKD训练器,用于实现广义知识蒸馏LogCompletionsCallback, # 记录生成结果的回调函数,用于评估ModelConfig, # 模型相关配置(如LoRA参数、量化设置等)ScriptArguments, # 脚本级参数(如数据集路径、分裂等)TrlParser, # trl库专用的参数解析器get_kbit_device_map, # 获取量化模型的设备映射(自动分配GPU/CPU)get_peft_config, # 获取PEFT配置(如LoRA参数)get_quantization_config, # 获取量化配置(如4/8位量化)
)if __name__ == "__main__":# 初始化参数解析器,支持解析三类配置:脚本参数、GKD训练配置、模型配置parser = TrlParser((ScriptArguments, GKDConfig, ModelConfig))# 解析命令行参数,得到三个配置对象# script_args:数据集路径、分裂等脚本级参数# training_args:GKD训练的核心参数(学习率、批次大小等)# model_args:模型相关参数(LoRA配置、量化设置等)script_args, training_args, model_args = parser.parse_args_and_config()################# 模型与分词器配置################# 根据model_args获取量化配置(如4位/8位量化),用于减少显存占用quantization_config = get_quantization_config(model_args)# 定义学生模型的初始化参数model_kwargs = dict(revision=model_args.model_revision, # 模型版本(如特定commit哈希)trust_remote_code=model_args.trust_remote_code, # 是否信任模型的自定义代码(如非标准架构)attn_implementation=model_args.attn_implementation, # 注意力实现方式(如flash attention加速)torch_dtype=model_args.torch_dtype, # 数据类型(如float16/bfloat16,节省显存)# 启用梯度检查点时禁用缓存(两者冲突),否则启用缓存加速use_cache=False if training_args.gradient_checkpointing else True,# 量化时自动分配设备(GPU/CPU),非量化时不指定device_map=get_kbit_device_map() if quantization_config is not None else None,quantization_config=quantization_config, # 量化配置(如4位量化参数))# 将学生模型参数传递给训练配置training_args.model_init_kwargs = model_kwargs# 定义教师模型的初始化参数(与学生模型类似,但有细微差别)teacher_model_kwargs = dict(revision=model_args.model_revision,trust_remote_code=model_args.trust_remote_code,attn_implementation=model_args.attn_implementation,torch_dtype=model_args.torch_dtype,use_cache=True, # 教师模型仅用于推理,启用缓存加速生成device_map=get_kbit_device_map() if quantization_config is not None else None,quantization_config=quantization_config,)# 将教师模型参数传递给训练配置training_args.teacher_model_init_kwargs = teacher_model_kwargs# 加载分词器(与学生模型匹配,确保格式一致)tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, # 分词器路径(与学生模型相同)revision=model_args.model_revision,trust_remote_code=model_args.trust_remote_code,padding_side="left", # 左填充(生成任务常用,避免右填充影响生成逻辑))# 若分词器未定义pad_token,用eos_token代替(确保填充功能正常)if tokenizer.pad_token is None:tokenizer.pad_token = tokenizer.eos_token################# 数据集加载################# 加载指定数据集(如trl-lib/chatbot_arena_completions对话数据集)# script_args.dataset_name:数据集名称,script_args.dataset_config:数据集配置(如子数据集)dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)################# 训练初始化################# 初始化GKD训练器,核心组件trainer = GKDTrainer(model=model_args.model_name_or_path, # 学生模型路径(如Qwen/Qwen2-0.5B-Instruct)teacher_model=training_args.teacher_model_name_or_path, # 教师模型路径(如Qwen/Qwen2-1.5B-Instruct)args=training_args, # 训练配置(学习率、批次大小等)train_dataset=dataset[script_args.dataset_train_split], # 训练集(如dataset["train"])# 验证集(若启用评估)eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,processing_class=tokenizer, # 用于数据预处理的分词器peft_config=get_peft_config(model_args), # PEFT配置(如LoRA参数:r=64, alpha=16))# 若启用评估策略(如每轮评估),配置生成参数并添加回调if training_args.eval_strategy != "no":# 定义生成配置(控制模型生成行为)generation_config = GenerationConfig(max_new_tokens=training_args.max_new_tokens, # 最大生成长度do_sample=True, # 启用采样(而非贪心生成)temperature=training_args.temperature # 温度参数(控制生成多样性,值越大越随机))# 初始化回调函数:记录评估时的生成结果(如保存8个示例)completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)# 向训练器添加回调trainer.add_callback(completions_callback)# 启动训练trainer.train()# 保存模型到输出目录trainer.save_model(training_args.output_dir)# 若启用push_to_hub,将模型推送到Hugging Face Hubif training_args.push_to_hub:trainer.push_to_hub(dataset_name=script_args.dataset_name)
GKDTrainer
import os
import random
import textwrap
from typing import Any, Callable, Optional, Unionimport torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import Dataset
from transformers import (AutoModelForCausalLM, # 用于加载因果语言模型(如GPT类模型)BaseImageProcessor, # 图像处理基类(此处未直接使用,为兼容多模态预留)DataCollator, # 数据整理器,用于批量处理数据FeatureExtractionMixin, # 特征提取混入类(兼容多模态)GenerationConfig, # 生成配置,控制模型生成行为(如长度、温度等)PreTrainedModel, # 预训练模型基类PreTrainedTokenizerBase, # 预训练分词器基类ProcessorMixin, # 处理器混入类(兼容多模态处理器)is_wandb_available, # 检查是否安装wandb(实验跟踪工具)
)
from transformers.trainer_callback import TrainerCallback # 训练回调基类
from transformers.trainer_utils import EvalPrediction # 评估预测结果格式
from transformers.utils import is_peft_available # 检查是否安装PEFT(参数高效微调工具)from ..models import prepare_deepspeed # 准备Deepspeed配置(分布式训练)
from ..models.utils import unwrap_model_for_generation # 为生成任务解包模型(如处理PEFT包装)
from .gkd_config import GKDConfig # GKD训练的核心配置类
from .sft_trainer import SFTTrainer # 监督微调训练器(GKDTrainer的父类)
from .utils import (DataCollatorForChatML, # 针对ChatML格式的数据集整理器disable_dropout_in_model, # 禁用模型中的dropout层(稳定训练)empty_cache, # 清空GPU缓存(节省显存)generate_model_card, # 生成模型卡片(README.md)get_comet_experiment_url, # 获取Comet实验跟踪URL(若使用)
)# 条件导入:仅当PEFT库可用时导入PeftConfig
if is_peft_available():from peft import PeftConfig# 条件导入:仅当wandb可用时导入wandb(实验跟踪)
if is_wandb_available():import wandbclass GKDTrainer(SFTTrainer):"""广义知识蒸馏(Generalized Knowledge Distillation)训练器,继承自监督微调训练器(SFTTrainer)。核心功能:通过教师模型指导学生模型训练,结合动态生成样本(on-policy学习)和广义JSD损失,提升小模型性能。"""_tag_names = ["trl", "gkd"] # 模型卡片标签def __init__(self,model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, # 学生模型(可传入路径或实例)teacher_model: Union[PreTrainedModel, nn.Module, str] = None, # 教师模型(可传入路径或实例)args: Optional[GKDConfig] = None, # GKD训练配置(含蒸馏参数、训练超参等)data_collator: Optional[DataCollator] = None, # 数据整理器(默认为ChatML格式)train_dataset: Optional[Dataset] = None, # 训练数据集eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, # 评估数据集processing_class: Optional[Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]] = None, # 数据处理器(通常为分词器)compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, # 评估指标计算函数callbacks: Optional[list[TrainerCallback]] = None, # 训练回调(如日志、早停等)optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), # 优化器和学习率调度器preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, # 处理logits用于计算指标peft_config: Optional["PeftConfig"] = None, # PEFT配置(如LoRA参数)formatting_func: Optional[Callable] = None, # 数据格式化函数(将样本转为模型输入格式)):# 禁用自动移除未使用的列(因GKD需要"prompts"等额外字段)args.remove_unused_columns = False# 初始化数据整理器:使用ChatML格式(适合对话模型),限制最大长度data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length)# 调用父类(SFTTrainer)的初始化方法,完成基础训练器配置super().__init__(model,args=args,data_collator=data_collator,train_dataset=train_dataset,eval_dataset=eval_dataset,processing_class=processing_class,compute_metrics=compute_metrics,callbacks=callbacks,optimizers=optimizers,preprocess_logits_for_metrics=preprocess_logits_for_metrics,peft_config=peft_config,formatting_func=formatting_func,)# 处理教师模型的初始化参数if args.teacher_model_init_kwargs is None:teacher_model_init_kwargs = {} # 无参数时使用空字典elif not isinstance(teacher_model, str):# 若教师模型已实例化,则不允许传入初始化参数(避免冲突)raise ValueError("已传入实例化的teacher_model,但同时指定了teacher_model_init_kwargs,两者冲突。")else:teacher_model_init_kwargs = args.teacher_model_init_kwargs# 处理数据类型参数(将字符串转为torch dtype,如"float16"→torch.float16)teacher_model_init_kwargs["torch_dtype"] = (teacher_model_init_kwargs["torch_dtype"]if teacher_model_init_kwargs["torch_dtype"] in ["auto", None]else getattr(torch, teacher_model_init_kwargs["torch_dtype"]))# 若教师模型是路径字符串,则加载预训练模型if isinstance(teacher_model, str):teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)# 禁用学生模型的dropout层(稳定蒸馏过程,减少随机性)if args.disable_dropout:disable_dropout_in_model(self.model)# 准备教师模型:若启用Deepspeed(分布式训练),则适配Deepspeed;否则用accelerator准备if self.is_deepspeed_enabled:self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator)else:# 将教师模型设为评估模式(不训练,仅用于推理)self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)# 初始化GKD核心超参数self.lmbda = args.lmbda # 学生自生成样本的概率(on-policy学习概率)self.beta = args.beta # 广义JSD损失的插值系数self.temperature = args.temperature # 概率分布的温度系数(控制平滑度)self.seq_kd = args.seq_kd # 是否强制使用教师生成的序列进行蒸馏# 初始化生成配置(控制动态样本生成的行为)self.generation_config = GenerationConfig(max_new_tokens=args.max_new_tokens, # 最大生成token数temperature=args.temperature, # 生成温度(值越大越随机)do_sample=True, # 启用采样(而非贪心生成)top_k=0, # 不限制top_k(配合温度控制多样性)use_cache=False if args.gradient_checkpointing else True, # 梯度检查点启用时禁用缓存(冲突)pad_token_id=self.processing_class.pad_token_id, # 填充token ID)# 适配模型自定义的EOS token(如Llama 3的<|eot_id|>)if (hasattr(self.model.generation_config, "eos_token_id")and self.model.generation_config.eos_token_id is not None):self.generation_config.eos_token_id = self.model.generation_config.eos_token_id@staticmethoddef generalized_jsd_loss(student_logits, # 学生模型的logits,形状:(batch_size, seq_len, vocab_size)teacher_logits, # 教师模型的logits,形状同上labels=None, # 标签,形状:(batch_size, seq_len),-100表示padding(忽略)beta=0.5, # 插值系数(控制教师/学生分布权重)temperature=1.0, # 温度系数(软化概率分布)reduction="batchmean", # 损失聚合方式(batchmean/sum/mean)):"""计算广义Jensen-Shannon散度(JSD)损失,用于知识蒸馏。参考论文:https://huggingface.co/papers/2306.13649 公式(1)"""# 温度缩放:软化概率分布(温度越高,分布越平滑)student_logits = student_logits / temperatureteacher_logits = teacher_logits / temperature# 计算学生和教师的对数概率(log_softmax)student_log_probs = F.log_softmax(student_logits, dim=-1)teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)if beta == 0:# beta=0:退化为传统KL散度(学生模仿教师)# F.kl_div(input=学生对数概率, target=教师对数概率, log_target=True)jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)elif beta == 1:# beta=1:反向KL散度(教师模仿学生,适合学生容量较小时避免模式崩溃)jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)else:# 混合分布的对数概率:log[(1-beta)*P_student + beta*P_teacher]# 等价于log(exp(log(1-beta) + log_P_student) + exp(log(beta) + log_P_teacher))beta = torch.tensor(beta, dtype=student_log_probs.dtype) # 转为tensor(匹配设备和类型)mixture_log_probs = torch.logsumexp(torch.stack([student_log_probs + torch.log(1 - beta), teacher_log_probs + torch.log(beta)]),dim=0, # 按第0维(学生/教师)求和)# 计算混合分布与教师/学生分布的KL散度(注意PyTorch的KL顺序与数学定义相反)kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)# 广义JSD:beta*KL(混合||教师) + (1-beta)*KL(混合||学生)jsd = beta * kl_teacher + (1 - beta) * kl_student# 掩码处理:忽略padding位置(labels=-100)的损失if labels is not None:mask = labels != -100 # 有效位置为True,padding为Falsejsd = jsd[mask] # 只保留有效位置的损失# 损失聚合(根据reduction参数)if reduction == "batchmean":# 按有效样本数平均(避免padding影响)if labels is not None:return jsd.sum() / mask.sum() # 总损失 / 有效token数else:return jsd.sum() / (jsd.size(0) * jsd.size(1)) # 总损失 / 总token数(无标签时)elif reduction == "sum":return jsd.sum() # 求和elif reduction == "mean":return jsd.mean() # 简单平均else:return jsd # 不聚合,返回原始损失 tensordef compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):"""计算GKD的损失:通过广义JSD损失让学生模仿教师的输出分布。"""# 学生模型前向传播(获取logits)outputs_student = model(input_ids=inputs["input_ids"], # 输入token IDattention_mask=inputs["attention_mask"], # 注意力掩码(0表示padding))# 教师模型前向传播(评估模式,不计算梯度)self.teacher_model.eval() # 确保教师模型处于评估模式(禁用dropout等)with torch.no_grad(): # 禁用梯度计算(节省显存)outputs_teacher = self.teacher_model(input_ids=inputs["input_ids"],attention_mask=inputs["attention_mask"],)# 切片处理:只保留生成部分的logits(排除输入prompt部分)prompt_lengths = inputs["prompts"].shape[1] # prompt的长度(输入部分,无需预测)# 学生logits:从prompt结束位置的前一个token开始,到序列结束前一个token(因语言模型预测下一个token)shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :]# 教师logits:同上(与学生对齐)shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :]# 标签:从prompt结束位置开始(生成部分的真实标签)shifted_labels = inputs["labels"][:, prompt_lengths:]# 计算广义JSD损失loss = self.generalized_jsd_loss(student_logits=shifted_student_logits,teacher_logits=shifted_teacher_logits,labels=shifted_labels, # 用于掩码paddingbeta=self.beta, # 从初始化参数获取)# 清空GPU缓存(节省显存)empty_cache()# 返回损失(可选返回学生模型输出)return (loss, outputs_student) if return_outputs else loss@staticmethoddef generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):"""生成动态样本(on-policy学习用):基于输入prompt生成输出序列,作为新的训练数据。"""# 基于prompt生成输出(仅用prompt作为输入,不包含原始标签)generated_outputs = model.generate(input_ids=inputs["prompts"], # 输入prompt的token IDattention_mask=inputs.get("prompt_attention_mask", None), # prompt的注意力掩码generation_config=generation_config, # 生成配置(长度、温度等)return_dict_in_generate=True, # 返回详细生成结果(含序列、分数等))# 获取生成的token ID序列generated_tokens = generated_outputs.sequences# 初始化新的注意力掩码(全1,后续修正padding位置)new_attention_mask = torch.ones_like(generated_tokens)# 新标签:复制生成的token(后续修正padding位置)new_labels = generated_tokens.clone()# 处理padding token(若指定)if pad_token_id is not None:# 标签中padding位置设为-100(忽略损失)new_labels[new_labels == pad_token_id] = -100# 注意力掩码中padding位置设为0(不参与注意力计算)new_attention_mask[generated_tokens == pad_token_id] = 0# 返回生成的输入ID、注意力掩码、标签return generated_tokens, new_attention_mask, new_labelsdef training_step(self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None) -> torch.Tensor:"""单步训练:实现on-policy学习,动态生成样本用于训练。逻辑:以概率lmbda使用学生自生成样本,或强制使用教师生成样本(seq_kd=True)。"""if self.seq_kd:# seq_kd=True:强制使用教师模型生成样本(适合初始训练阶段,学习教师的"正确"输出)# 解包教师模型(处理PEFT/分布式包装)with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model:# 生成教师样本new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id)# 更新输入为教师生成的样本inputs["input_ids"] = new_input_idsinputs["attention_mask"] = new_attention_maskinputs["labels"] = new_labels# 以概率lmbda使用学生自生成样本(on-policy学习核心)if random.random() <= self.lmbda:# 解包学生模型(处理PEFT/分布式包装)with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:# 生成学生样本new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id)# 更新输入为学生生成的样本(让学生从自身生成的结果中学习)inputs["input_ids"] = new_input_idsinputs["attention_mask"] = new_attention_maskinputs["labels"] = new_labels# 调用父类的training_step计算损失并更新参数loss = super().training_step(model, inputs, num_items_in_batch)return lossdef create_model_card(self,model_name: Optional[str] = None, # 模型名称dataset_name: Optional[str] = None, # 训练数据集名称tags: Union[str, list[str], None] = None, # 模型标签):"""生成模型卡片(README.md),包含训练信息、引用、标签等,方便上传到Hugging Face Hub。"""# 仅在主进程执行(避免多进程重复生成)if not self.is_world_process_zero():return# 确定基座模型名称(若模型从预训练模型微调而来)if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):base_model = self.model.config._name_or_pathelse:base_model = None# 标准化标签(转为集合避免重复)if tags is None:tags = set()elif isinstance(tags, str):tags = {tags}else:tags = set(tags)# 若使用unsloth加速训练,添加对应标签if hasattr(self.model.config, "unsloth_version"):tags.add("unsloth")# 添加默认标签(trl和gkd)tags.update(self._tag_names)# GKD论文引用格式citation = textwrap.dedent("""\@inproceedings{agarwal2024on-policy,title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem},year = 2024,booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},publisher = {OpenReview.net},url = {https://openreview.net/forum?id=3zKtaqxLhW},}""")# 生成模型卡片内容model_card = generate_model_card(base_model=base_model, # 基座模型model_name=model_name, # 模型名称hub_model_id=self.hub_model_id, # Hub上的模型IDdataset_name=dataset_name, # 训练数据集tags=tags, # 标签wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None, # wandb实验URLcomet_url=get_comet_experiment_url(), # Comet实验URLtrainer_name="GKD", # 训练器名称trainer_citation=citation, # 引用paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes", # 论文标题paper_id="2306.13649", # 论文ID(arXiv或OpenReview))# 保存模型卡片到输出目录model_card.save(os.path.join(self.args.output_dir, "README.md"))
GKDConfig
from dataclasses import dataclass, field
from typing import Any, Optionalfrom transformers import TrainingArguments # 导入Hugging Face的训练参数基类from .sft_config import SFTConfig # 导入监督微调(SFT)的配置类(GKDConfig的父类)@dataclass
class GKDConfig(SFTConfig):"""广义知识蒸馏(Generalized Knowledge Distillation, GKD)训练器的配置类。此类仅包含GKD训练特有的参数,完整的训练参数请参考`transformers.TrainingArguments`和`SFTConfig`的文档。参数说明:temperature (`float`, 可选, 默认值 `0.9`):采样温度。温度越高,生成的结果随机性越强。lmbda (`float`, 可选, 默认值 `0.5`):控制学生自生成样本比例的参数(即on-policy学习中,使用学生自己生成的输出进行训练的比例)。beta (`float`, 可选, 默认值 `0.5`):广义Jensen-Shannon散度(JSD)损失的插值系数,范围在`0.0`到`1.0`之间。当beta=0.0时,损失退化为传统KL散度(学生模仿教师);当beta=1.0时,损失为反向KL散度(教师模仿学生)。max_new_tokens (`int`, 可选, 默认值 `128`):每次生成的最大token数量。teacher_model_name_or_path (`str` 或 `None`, 可选, 默认值 `None`):教师模型的名称或路径。若为`None`,则教师模型与当前训练的模型相同。teacher_model_init_kwargs (`dict[str, Any]` 或 `None`, 可选, 默认值 `None`):从字符串实例化教师模型时,传递给`AutoModelForCausalLM.from_pretrained`的关键字参数。disable_dropout (`bool`, 可选, 默认值 `True`):是否禁用模型中的dropout层(蒸馏中常用,以减少随机性,稳定训练)。seq_kd (`bool`, 可选, 默认值 `False`):是否执行序列级蒸馏(Sequence-Level KD),可视为在教师生成的输出上进行监督微调。"""# 扩展有效字典字段:在TrainingArguments的基础上添加教师模型的初始化参数_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"]temperature: float = field(default=0.9,metadata={"help": "采样温度。温度越高,生成的结果随机性越强。"},)lmbda: float = field(default=0.5,metadata={"help": "控制学生自生成样本比例的参数(即on-policy学习中,使用学生自己生成的输出进行训练的比例)。"},)beta: float = field(default=0.5,metadata={"help": "广义Jensen-Shannon散度(JSD)损失的插值系数,范围在0.0到1.0之间。""当beta=0.0时,损失为KL散度(学生模仿教师);当beta=1.0时,损失为反向KL散度(教师模仿学生)。"},)max_new_tokens: int = field(default=128,metadata={"help": "每次生成的最大token数量。"},)teacher_model_name_or_path: Optional[str] = field(default=None,metadata={"help": "教师模型的名称或路径。若为None,教师模型将与当前训练的模型相同。"},)teacher_model_init_kwargs: Optional[dict[str, Any]] = field(default=None,metadata={"help": "从字符串实例化教师模型时,传递给AutoModelForCausalLM.from_pretrained的关键字参数。"},)disable_dropout: bool = field(default=True,metadata={"help": "是否禁用模型中的dropout层(蒸馏中常用以稳定训练)。"},)seq_kd: bool = field(default=False,metadata={"help": "是否执行序列级蒸馏(可视为在教师生成的输出上进行监督微调)。"},)def __post_init__(self):"""初始化后执行的方法:调用父类初始化逻辑,并验证参数合法性。"""super().__post_init__() # 调用父类(SFTConfig)的初始化后处理逻辑# 验证lmbda参数是否在[0, 1]范围内if self.lmbda < 0.0 or self.lmbda > 1.0:raise ValueError("lmbda参数必须在[0.0, 1.0]范围内。")# 验证beta参数是否在[0, 1]范围内if self.beta < 0.0 or self.beta > 1.0:raise ValueError("beta参数必须在[0.0, 1.0]范围内。")