当前位置: 首页 > ds >正文

基于GRPO将QWEN训练为和deepseek一样的推理模型!

 GRPO

群体相对策略优化(GRPO)算法最初由deepseek团队提出,是近端策略优化(PPO)的一个变体。

GRPO 是一种在线学习算法,它通过使用训练过程中已训练模型自身生成的数据进行迭代改进。GRPO 目标背后的逻辑是在确保模型与参考策略保持接近的同时,最大化生成的completion的优势。

DeepSeek团队在使用纯强化学习 训练 R1-Zero 时观察到了一个“aha moment”。该模型学会了通过重新评估其初始方法来延长其思考时间,而无需任何人工指导或预定义指令。

  1. 该模型将生成多个响应

  2. 每个响应都根据正确性或由某奖励函数而不是 LLM 奖励模型创建的另一个指标进行评分

  3. 计算该组的平均分数

  4. 将每个响应的分数与组平均值进行比较

  5. 该模型经过强化,有利于得分较高的响应

例如,假设我们想要一个模型来求解:

1+1=?>> 思维链/解决 >> 答案是 2.                          

2+2=?>> 思维链/解决 >> 答案是 4.

以前必须收集大量数据来填充锻炼/思维链过程。但是GRPO可以引导模型自动展示推理功能并创建推理跟踪。

构建聊天模板

提示模型在提供答案之前阐明其推理。需要先将提示和响应建立一个明确的格式。

# system prompt
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

数据集

选择OpenAI的 GSM8K 数据集,其中包含小学数学问题,每个答案都会有一个推理过程,并在 “####”后面附上最终答案。(该数据集在诸多论文中出现过)


 

数据集是通过提取答案并将其格式化为结构化字符串来准备的 

# 准备数据集
import re
from datasets import load_dataset, Dataset# 提取推理过程
def extract_xml_answer(text: str) -> str:answer = text.split("<answer>")[-1]answer = answer.split("</answer>")[0]return answer.strip()# 提取最终答案
def extract_hash_answer(text: str) -> str | None:if "####" not in text:return Nonereturn text.split("####")[1].strip()# 准备GSM8K数据集
def get_gsm8k_questions(split="train") -> Dataset:data = load_dataset("openai/gsm8k", "main")[split]data = data.map(lambda x: {"prompt": [{"role": "system", "content": SYSTEM_PROMPT},{"role": "user", "content": x["question"]},],"answer": extract_hash_answer(x["answer"]),})return datadataset = get_gsm8k_questions()

奖励函数

func1,依据模型生成的回复与标准答案的匹配状况来给出奖励分数。若回复和标准答案一致,就给予 2.0 的奖励;若不一致,则给予 0.0 的奖励

# 奖励标签完全匹配
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:responses = [completion[0]['content'] for completion in completions]q = prompts[0][-1]['content']extracted_responses = [extract_xml_answer(r) for r in responses]print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

func2,根据模型生成的回复是否为整数来给予奖励分数。若回复是一个整数形式的字符串,就给予 0.5 的奖励;若不是,则给予 0.0 的奖励

# 鼓励仅使用整数的答案
def int_reward_func(completions, **kwargs) -> list[float]:responses = [completion[0]['content'] for completion in completions]extracted_responses = [extract_xml_answer(r) for r in responses]return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

func3,检查模型生成的回复是否符合特定的格式要求。如果回复符合格式要求,则给予 0.5 的奖励;如果不符合,则给予 0.0 的奖励。

#  确保响应结构与提示匹配,包括换行符
def strict_format_reward_func(completions, **kwargs) -> list[float]:"""Reward function that checks if the completion has a specific format."""pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"responses = [completion[0]["content"] for completion in completions]matches = [re.match(pattern, r) for r in responses]return [0.5 if match else 0.0 for match in matches]

func4,检查模型生成的回复是否符合特定的格式要求。若回复符合格式,就给予 0.5 的奖励;若不符合,则给予 0.0 的奖励。与之前的 strict_format_reward_func 相比,它的格式匹配要求没那么严格。

# 检查结构,但允许轻微的换行符不匹配
def soft_format_reward_func(completions, **kwargs) -> list[float]:"""Reward function that checks if the completion has a specific format."""pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"responses = [completion[0]["content"] for completion in completions]matches = [re.match(pattern, r) for r in responses]return [0.5 if match else 0.0 for match in matches]

func5,检查文本中 <reasoning></reasoning><answer> 和 </answer> 存在与否,并根据其出现情况进行控分;同时根据 </answer> 之后的文本长度扣分

# 确保响应中的每个 XML 标签中只有一个,同时根据长度扣分
def count_xml(text) -> float:count = 0.0if text.count("<reasoning>\n") == 1:count += 0.125if text.count("\n</reasoning>\n") == 1:count += 0.125if text.count("\n<answer>\n") == 1:count += 0.125count -= len(text.split("\n</answer>\n")[-1])*0.001if text.count("\n</answer>") == 1:count += 0.125count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001return countdef xmlcount_reward_func(completions, **kwargs) -> list[float]:contents = [completion[0]["content"] for completion in completions]return [count_xml(c) for c in contents]

加载模型

选择Qwen-2.5-3B-Instruct模型,非推理模型

from unsloth import FastLanguageModel, is_bfloat16_supported
import torch
max_seq_length = 1024 # Can increase for longer reasoning traces
lora_rank = 64 # Larger rank = smarter, but slowermodel, tokenizer = FastLanguageModel.from_pretrained(model_name = "Qwen/Qwen2.5-3B-Instruct",max_seq_length = max_seq_length,load_in_4bit = True, # False for LoRA 16bitfast_inference = True, # Enable vLLM fast inferencemax_lora_rank = lora_rank,gpu_memory_utilization = 0.5, # Reduce if out of memory
)

 LORA配置

model = FastLanguageModel.get_peft_model(model,r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128target_modules = ["q_proj", "k_proj", "v_proj", "o_proj","gate_proj", "up_proj", "down_proj",], # Remove QKVO if out of memorylora_alpha = lora_rank,use_gradient_checkpointing = "unsloth", # Enable long context finetuningrandom_state = 3407,
)

GRPO配置

Transformer 强化学习(TRL)是一个全栈库,提供了一套工具,可使用诸如监督微调(SFT)、组相对策略优化(GRPO)、直接偏好优化(DPO)、奖励建模等方法来训练基于 Transformer 的语言模型。TRL已与Transformers库集成

use_vllm参数需要使用vllm库,vllm用于加速模型推理,其官方宣称吞吐量比HF的transformer 高 24 倍 ,无需更改任何模型架构。

from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(use_vllm = True, # use vLLM for fast inference!learning_rate = 5e-6,adam_beta1 = 0.9,adam_beta2 = 0.99,weight_decay = 0.1,warmup_ratio = 0.1,lr_scheduler_type = "cosine",optim = "adamw_8bit",logging_steps = 1,bf16 = is_bfloat16_supported(),fp16 = not is_bfloat16_supported(),per_device_train_batch_size = 1,gradient_accumulation_steps = 1, # Increase to 4 for smoother trainingnum_generations = 8, # Decrease if out of memorymax_prompt_length = 256,max_completion_length = 200,# num_train_epochs = 1, # Set to 1 for a full training runmax_steps = 250,save_steps = 250,max_grad_norm = 0.1,report_to = "none", # Can use Weights & Biasesoutput_dir = "outputs",
)

开始训练

一块Tesla T4 GPU的训练耗时为1h出头(博主实测)。

trainer = GRPOTrainer(model = model,processing_class = tokenizer,reward_funcs = [xmlcount_reward_func,soft_format_reward_func,strict_format_reward_func,int_reward_func,correctness_reward_func,],args = training_args,train_dataset = dataset,
)
trainer.train()

测试评估

先加载之前的模型

text = tokenizer.apply_chat_template([{"role" : "user", "content" : "How many r's are in strawberry?"},
], tokenize = False, add_generation_prompt = True)from vllm import SamplingParams
sampling_params = SamplingParams(temperature = 0.8,top_p = 0.95,max_tokens = 1024,
)
output = model.fast_generate([text],sampling_params = sampling_params,lora_request = None,
)[0].outputs[0].text

在加载训练后的模型

# 
model.save_lora("grpo_saved_lora")text = tokenizer.apply_chat_template([{"role" : "system", "content" : SYSTEM_PROMPT},{"role" : "user", "content" : "How many r's are in strawberry?"},
], tokenize = False, add_generation_prompt = True)from vllm import SamplingParams
sampling_params = SamplingParams(temperature = 0.8,top_p = 0.95,max_tokens = 1024,
)
output = model.fast_generate(text,sampling_params = sampling_params,lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].textprint(output)

http://www.xdnf.cn/news/453.html

相关文章:

  • linux kernel irq相关函数详解
  • 国产的 Java Solon v3.2.0 发布(央企信创的优选)
  • Day10【基于encoder- decoder架构实现新闻文本摘要的提取】
  • 第3章 垃圾收集器与内存分配策略《深入理解Java虚拟机:JVM高级特性与最佳实践(第3版)》
  • 从0开发一个unibest+vue3项目,使用vscode编辑器开发,总结vue2升vue3项目开始,小白前期遇到的问题
  • Flask应用部署通用指南
  • [Windows] Wireshark 网络抓包工具 v4.4.6
  • 【MATLAB海洋专题】历史汇总
  • C++学习之路,从0到精通的征途:vector类的模拟实现
  • 算法期末复习
  • 基于Fabric.js的选座布局系统开发笔记
  • 如何系统地入门学习stm32?
  • Linux419 三次握手四次挥手抓包 wireshark
  • python程序的流程
  • 移动自动化测试-appium
  • thanos rule组件和prometheus区别?
  • 5G基站设计难题:尺寸、重量、功耗和散热
  • HashMap 初步理解 put 操作流程 HashMap 的线程安全问题
  • 精益数据分析(5/126):解锁创业成功的关键密码
  • FPGA练习———DDS波形发生器
  • 一个项目中多个Composer的使用方法
  • 读文献方法
  • Linux系统之diff3命令详解
  • Nodejs数据库单一连接模式和连接池模式的概述及写法
  • 2024-04-19| Java: Documented注解学习 JavaDoc
  • 【网络篇】TCP vs UDP底层区别+网络编程概念
  • 网络爬虫和前端相关知识
  • 根据Exif信息纠正图片方向
  • MYSQL初阶(暂为自用草稿)
  • 哔哩哔哩工具箱:BiliTools v1.3.2