当前位置: 首页 > news >正文

强化微调:以Swift框架进行GRPO多模态模型强化微调为例

一、TL;DR

  1. 整体介绍:强化微调RFT的原因、步骤、作用以及常见的rft方式
  2. dmeo举例:以Swift给的Qwen2.5-Math-7B-Instruct为例介绍了整个RFT的流程和代码细节
  3. 实际强化微调:以qwen/internVL为例完成一次指令微调并且使用强化学习进一步提升指标

二、整体介绍

2.1 为什么要做强化微调

掉点/回退现象:

基础MLLM经过含有CoT训练集上做SFT后,发现在test集上掉点,可以通过强化微调来确保不会发生这种情况

  1. 在LLaMA3上,使用gsm8k训练集训练llama3.1-8b-instruct,对生成的ckpt使用test集进行评测,会发现掉点。

原因:

模型的知识遗忘,举例如下:

  1. 正常流程:在微调的时候会加入非常多的CoT数据集
  2. 造成结果:在继续训练通用任务后,知识遗忘破坏了模型原有能力,导致了掉点。
  3. 原因分析:当模型在解决数学任务的时候,用到的能力很有可能不是来自于math数据集,而是来自arc数据集,

2.2 什么时候可以使用强化微调

当有如下条件之一时使用强化微调:

  1. 已经微调过模型,能力不满足需求
  2. 需要更强的CoT能力
  3. 对基模型训练通用能力,而原始数据集已经导致模型效果无法提升
  4. 对应query的输出结果可以相对准确地评估好坏,例如结果清晰(数学,代码),过程清晰(翻译,风格)等

强化微调非常依赖于reward评估是否准确。如果评估结果不准确,可能导致模型训练原地震荡,甚至越训越差。

2.3 强化微调的步骤

2.3.1 使用某个模型生成数据/进行原始数据扩充然后采样

  1. 大模型生成数据:使用GPT、Qwen-Max、DeepSeek-V3/R1等生成和扩充数据,则该强化微调可以理解为蒸馏

  2. 模型本身生成数据:可以理解为自我提升(self-improvement)微调
  3. 采样过程-on-policy算法:采样一个batch,然后通过KL散度和reward进行拟合训练并不断循环
  4. 采样算法:包含蒙特卡洛采样、do_sample采样、group beam search、dvts等
  5. 采样过程额外引入细节:可以引入ORM(结果判断),PRM(过程打分),多样性过滤,语种过滤等

2.3.2 使用数据训练目标模型

训练的方式:

  1. 如果使用SFT,则称为拒绝采样微调
  2. 如果是强化学习,则称为强化学习微调

2.3.3 根据需要判断是否重复上述过程

  1. 如果使用更大的模型蒸馏,例如更大模型的蒙特卡洛采样蒸馏,一般不会有循环

  2. 如果使用本模型进行采样,或者PPO等算法,则会有循环

2.4 常见的强化微调方式

  1. 蒸馏:使用蒙特卡洛、do_sample等方式从超大模型中采样大量优质数据,训练小模型
  2. 自我提升:从本模型中采样部分优质数据,筛选后训练本模型,循环执行
  3. on-policy RL:使用PPO、GRPO等方式循环训练

2.5 ms-swift的展示demo

SFT和RFT的区别:

使用competition_math直接SFT后,instruct模型的掉点十分严重。而RFT后模型能力有提升,即使对Qwen2.5_math_7b_instruct这个SOTA的math模型也同样有一定提升空间。

同样可以发现,Qwen2.5这个模型经过RFT后在原有的其他数据集gsm8k上也没有出现大幅度回退(这就是为什么比SFT好的原因,新数据集上有效果,旧数据集上不坍塌)。

参考资料:强化微调 — swift 3.8.0.dev0 文档

三、demo代码分析

3.1 main函数分析

遵循第二节的流程:

  1. 先采样;
  2. 再做RLT
  3. 再做循环-5次

注意:以上这些流程都是使用python拼接输入命令行,不是一个函数就搞定了所有的代码哈,核心的这些命令行的功能都被swift封装在框架里面了,尤其是PRM模型的选取这些。

3.2 do-sample采样函数

如下图所示,过程奖励模型使用了Qwen2.5-Math-PRM-7B模型,为每一块GPU上生成了一个采样的RFT数据集

PRM模型和PRM_threshold如何配合形成采样数据集:

3.3 do_train训练函数

直接将rlhf的训练type写入启动脚本,开始强化微调:

代码参考:https://github.com/modelscope/ms-swift/blob/main/examples/train/rft/rft.py

四、实际项目举例

闲下来再写吧 这个要记录自己的实验结果,我后续截图补充再写

http://www.xdnf.cn/news/1442719.html

相关文章:

  • 【明道云】[工作表控件5] 手机控件的格式化处理
  • 在麒麟 ARM (aarch64)安装OpenJDK11和elasticsearchkibana
  • 云手机中的三大核心技术主要是指什么?
  • Docker部署Lunalytics开源监控工具
  • 开源检索增强生成(UltraRAG)框架
  • Unity2018版本安卓打包环境配置问题
  • 搞定鸿蒙新手 3 大痛点:页面跳转实现、应用标识修改与 Hyper-V 启动故障排查
  • Elasticsearch(text和keyword)区别分析
  • 【教程】IDEA中导入springboot-maven工程
  • Git 别名:用简短命令大幅提升开发效率
  • 企业级AI应用,Dify集成RAGFlow知识库保姆教程
  • 少儿编程C++快速教程之——1. 基础语法和输入输出
  • 【STL源码剖析】从源码看 deque :拆解双端队列的底层实现与核心逻辑
  • 聚焦岗位能力提升:休闲服务与管理虚拟仿真实训室的实训设计与落地
  • 华为卫星对星引导技术深度解析:原理、实现与开源替代方案
  • 从 MMLU 到 HumanEval:为什么评估大型语言模型(LLM)的基准至关重要?
  • 计算机二级C语言操作题(填空、修改、设计题)——真题库(14)附解析答案
  • 医学图像配准的循环推理机|文献速递-深度学习人工智能医疗图像
  • Aerobits-用于 sUAS 和 UTM/U-Space 的微型 ADS-B 技术(收发器/接收器)和无人机跟踪应答器
  • 车载诊断架构 --- 从架构系统角度怎么确保整车DTC的完整性?
  • 蓝光三维扫描技术赋能内衣胸垫设计:从精准制造到个性化体验的革新之旅
  • 突破性能瓶颈:Scala爬虫的大规模数据处理方案
  • 【Lua】题目小练14
  • 为什么几行dropout就能显著提升稀疏3DGS渲染质量?
  • 深度学习篇---InceptionNet网络结构
  • 【串口助手】串口调试助手LTSerialTool v3.12.0发布
  • A股大盘数据-2025093分析
  • Java如何实现jar包方法覆盖
  • C语言字符函数和字符串函数(1)
  • TypeScript 与 Java 重载机制对比