基于 LoRA 和 GRPO 的 Qwen2.5-3B 数学推理模型微调示例
一、项目概述
本项目基于Qwen2.5-3B-Instruct模型,结合LoRA(低秩自适应)和GRPO技术,针对数学推理任务(GSM8K数据集)进行微调,旨在训练一个能以XML格式输出链式思考和答案的数学助理模型。通过多维度奖励函数引导模型生成符合格式要求且答案正确的响应。
二、关键技术与实现细节
1. 数据处理
- 数据集:使用GSM8K数学推理数据集的训练集,包含问题(question)和带推导过程的答案(answer)。
- 预处理:
- 通过
preprocess
函数将问题包装为包含系统提示的对话格式(System Prompt + User Question)。 - 从答案中提取最终数值结果(通过正则匹配
####
或\boxed{}
格式),用于后续奖励函数的正确性验证。
- 通过
2. 模型与LoRA配置
- 基础模型:加载Qwen2.5-3B-Instruct模型,启用梯度检查点(
gradient_checkpointing_enable
)以减少显存占用。 - LoRA参数:
- 秩(r)=16,alpha=32,dropout=0.05,仅优化注意力和前馈网络的关键模块(如
q_proj
、k_proj
等)。 - 可训练参数:目标模块数量 × 秩® × (输入维度 + 输出维度),显著降低微调成本。
- 秩(r)=16,alpha=32,dropout=0.05,仅优化注意力和前馈网络的关键模块(如
3. 奖励函数设计(GRPO核心)
- 正确性奖励:检查生成响应的XML答案是否包含正确数值,正确则奖励1.0,否则0.0。
- 格式奖励:
- 软检查:匹配是否包含
<reasoning>
和<answer>
结构,符合则奖励2.0。 - 严格检查:要求响应严格以XML格式开头和结尾,无多余内容,符合则奖励4.0。
- 软检查:匹配是否包含
- 多奖励组合:通过线性叠加引导模型同时满足格式规范性和答案准确性。
4. GRPO训练配置
- 优化参数:
- 学习率2e-4,余弦调度器(warmup_ratio=0.05),梯度累积8步,批次大小16(单卡)。
- 最大提示长度512,生成长度64,每次采样8个候选响应进行优化。
- 自定义训练器:继承GRPOTrainer,使用AdamW优化器和余弦学习率调度,适配LoRA训练流程。
三. 程序如下:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, get_cosine_schedule_with_warmup
from torch.optim import AdamW
from peft import LoraConfig, get_peft_model
from modelscope import snapshot_download
from trl import GRPOConfig, GRPOTrainer
import re
import torch# ========== 数据加载与预处理 ==========SYSTEM_PROMPT = "你是一个擅长用 XML 格式输出链式思考和答案的数学助理,使用这种格式进行回答<reasoning>...</reasoning>\n<answer>...</answer>"def extract_final_answer(text: str) -> str:match = re.search(r'####\s*(\d+)', text)if match:return match.group(1).strip()match = re.search(r'\\boxed\{(.*?)\}', text)if match:return match.group(1).strip()return ""def preprocess(example):return {"prompt": [{"role": "system", "content": SYSTEM_PROMPT},{"role": "user", "content": example["question"]},],"answer": extract_final_answer(example["answer"]),}raw = load_dataset("gsm8k", "main", split="train")
dataset = raw.map(preprocess, remove_columns=raw.column_names)# ========== 加载模型与 LoRA 配置 ==========MODEL_ID = "Qwen/Qwen2.5-3B-Instruct"
OUTPUT_DIR = "outputs/qwen3b-grpo-lora-fp16"
DEVICE_ID = 0
torch.cuda.set_device(DEVICE_ID)local_path = snapshot_download(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(local_path, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_tokenbase_model = AutoModelForCausalLM.from_pretrained(local_path,torch_dtype=torch.float16,device_map="auto",trust_remote_code=True,
)
base_model.gradient_checkpointing_enable()lora_cfg = LoraConfig(r=16,lora_alpha=32,lora_dropout=0.05,bias="none",task_type="CAUSAL_LM",target_modules=["q_proj", "k_proj", "v_proj", "o_proj","gate_proj", "up_proj", "down_proj",],
)model = get_peft_model(base_model, lora_cfg)
model.print_trainable_parameters()# ========== 定义奖励函数 ==========def extract_xml_answer(text: str) -> str:match = re.search('<answer>(.*)</answer>', text, re.DOTALL)return match.group(1).strip() if match else ""def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:responses = [completion[0]['content'] for completion in completions]extracted = [extract_xml_answer(r) for r in responses]return [1.0 if a in r else 0.0 for r, a in zip(extracted, answer)]def soft_format_reward_func(completions, **kwargs) -> list[float]:pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"responses = [c[0]["content"] for c in completions]return [2.0 if re.search(pattern, r, re.DOTALL) else 0.0 for r in responses]def strict_format_reward_func(completions, **kwargs) -> list[float]:pattern = r"^\s*<reasoning>.*?</reasoning>\s*<answer>.*?</answer>\s*$"responses = [c[0]["content"] for c in completions]return [4.0 if re.search(pattern, r, re.DOTALL) else 0.0 for r in responses]# ========== 定义 GRPO 配置 ==========train_args = GRPOConfig(fp16=True,per_device_train_batch_size=16,gradient_accumulation_steps=8,learning_rate=2e-4,num_train_epochs=1,lr_scheduler_type="cosine",warmup_ratio=0.05,max_grad_norm=0.3,logging_steps=1,save_steps=100,output_dir=OUTPUT_DIR,report_to="tensorboard",max_prompt_length=512,max_completion_length=64,num_generations=8,use_vllm=False,
)# ========== 自定义 GRPOTrainer:添加 AdamW 与 Cosine 调度器 ==========class CustomGRPOTrainer(GRPOTrainer):
# num_training_steps = (
# len(dataset) // (args.per_device_train_batch_size * args.gradient_accumulation_steps)
# ) * args.num_train_epochsdef create_optimizer_and_scheduler(self, num_training_steps: int):self.optimizer = AdamW(self.model.parameters(), lr=self.args.learning_rate)self.lr_scheduler = get_cosine_schedule_with_warmup(self.optimizer,num_warmup_steps=int(self.args.warmup_ratio * num_training_steps),num_training_steps=num_training_steps)# ========== 启动训练 ==========trainer = CustomGRPOTrainer(model=model,processing_class=tokenizer,reward_funcs=[soft_format_reward_func,strict_format_reward_func,correctness_reward_func],args=train_args,train_dataset=dataset,
)
trainer.train()# ========== 模型保存 ==========model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"✔ LoRA Adapter + Tokenizer 已保存到 {OUTPUT_DIR}")