当前位置: 首页 > ai >正文

基于 Python 的自然语言处理系列(82):Transformer Reinforcement Learning

🔗 本文所用工具:trltransformerspeftbitsandbytes
📘 官方文档参考:https://huggingface.co/docs/trl

一、引言:从有监督微调到 RLHF 全流程

        随着语言大模型的发展,如何在大规模预训练模型基础上更精细地对齐人类偏好,成为了研究与应用的热点。本文将介绍一套完整的 RLHF(Reinforcement Learning with Human Feedback)训练流程,基于 Hugging Face 推出的 trl 库,从 SFT(Supervised Fine-tuning)、RM(Reward Modeling)、到 PPO(Proximal Policy Optimization)三大阶段,逐步实现对 Transformer 模型的强化学习优化。

        本篇聚焦于 SFT 阶段的实现,并以 Hugging Face 提供的 instruction-dataset 为例,介绍如何使用 trl 和 PEFT(参数高效微调)技术训练一个高效对齐指令的语言模型。

二、安装与环境准备

        确保安装以下库(建议使用 PyTorch + CUDA 环境):

pip install trl transformers datasets peft bitsandbytes accelerate

三、加载并准备数据集

        本例使用 HuggingFaceH4 团队整理的 instruction-dataset

from datasets import load_datasetdataset = load_dataset("HuggingFaceH4/instruction-dataset")
dataset = dataset.remove_columns("meta")  # 移除无用字段
dataset

四、构建模型及量化配置(4-bit)

        使用 BitsAndBytesConfig 对模型进行 4-bit 量化,可大幅降低显存占用:

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import prepare_model_for_kbit_trainingmodel_name = "lmsys/fastchat-t5-3b-v1.0"bnb_config = BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_quant_type="nf4",bnb_4bit_compute_dtype=torch.bfloat16,
)base_model = AutoModelForCausalLM.from_pretrained(model_name,torch_dtype=torch.bfloat16,quantization_config=bnb_config
)base_model.config.use_cache = False
base_model = prepare_model_for_kbit_training(base_model)

五、注入 LoRA 参数高效微调机制

        首先识别所有 4-bit 线性模块并定义 LoRA 参数配置:

import bitsandbytes as bnb
from peft import get_peft_model, LoraConfigdef find_all_linear_names(model):cls = bnb.nn.Linear4bitlora_module_names = set()for name, module in model.named_modules():if isinstance(module, cls):names = name.split(".")lora_module_names.add(names[0] if len(names) == 1 else names[-1])return list(lora_module_names)peft_config = LoraConfig(r=128,lora_alpha=16,target_modules=find_all_linear_names(base_model),lora_dropout=0.05,bias="none",task_type="CAUSAL_LM",
)base_model = get_peft_model(base_model, peft_config)

        打印可训练参数占比:

def print_trainable_parameters(model):trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)total = sum(p.numel() for p in model.parameters())print(f"Trainable params: {trainable} / {total} ({trainable / total:.2%})")print_trainable_parameters(base_model)

六、定义 Prompt 格式化函数

        将数据集中的 promptcompletion 格式化为统一格式:

def formatting_prompts_func(example):return [f"### Input: ```{prompt}```\n ### Output: {completion}"for prompt, completion in zip(example["prompt"], example["completion"])]

七、训练参数设置与 SFTTrainer 训练器

        使用 SFTTrainer 执行指令微调训练,支持 gradient checkpointing、cosine 学习率调度等高级策略:

from transformers import TrainingArguments
from trl import SFTTraineroutput_dir = "./results"training_args = TrainingArguments(output_dir=output_dir,per_device_train_batch_size=4,gradient_accumulation_steps=4,gradient_checkpointing=True,max_grad_norm=0.3,num_train_epochs=15,learning_rate=2e-4,bf16=True,save_total_limit=3,logging_steps=10,optim="paged_adamw_32bit",lr_scheduler_type="cosine",warmup_ratio=0.05,
)tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"trainer = SFTTrainer(model=base_model,train_dataset=dataset,tokenizer=tokenizer,max_seq_length=2048,formatting_func=formatting_prompts_func,args=training_args
)

        执行训练:

trainer.train()
trainer.save_model(output_dir)

        保存最终模型权重与 tokenizer:

import os
final_output_dir = os.path.join(output_dir, "final_checkpoint")
trainer.model.save_pretrained(final_output_dir)
tokenizer.save_pretrained(final_output_dir)

八、小结与展望

        通过本文,我们使用 trl 工具链完成了 RLHF 的第一阶段:SFT 有监督微调。你可以根据项目实际需求,替换为自定义数据集或更大规模模型。后续步骤(RM 奖励建模 + PPO 策略优化)将在下一篇继续介绍。

📌 下一篇预告

        📘《基于 Python 的自然语言处理系列(83):RLHF 全流程之 PPO 强化微调》

        敬请期待!

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

http://www.xdnf.cn/news/494.html

相关文章:

  • Alan AI - 面向Web的生成式AI SDK
  • 基于C语言实现文件读取
  • Linux 第五讲 --- 权限管理
  • 6.常用控件-QWidget|windowTitle|windowIcon|qrc机制|windowOpacity|cursor(C++)
  • Amlogic S905L3 系列对比:L3A、L3B 与 L3AB 深度解析
  • Unity之如何实现RenderStreaming视频推流
  • 大学英语四级选词填空阅读题和段落匹配解析
  • 【Hot100】54. 螺旋矩阵
  • 2025.04.19-阿里淘天春招算法岗笔试-第一题
  • 金融数学专题6 证券问题与资本利得税
  • Pandas数据统计分析
  • MCS-51单片机汇编语言编程指南
  • ArcPy Mapping 模块基础
  • 3. 进程概念
  • 修改Theme SHELL美化panel
  • Docker 网络详解:从 docker0 网桥到网络命名空间
  • 复习JUC的总结笔记
  • 整流二极管详解:原理、作用、应用与选型要点
  • 什么是零缺陷质量管理?
  • DNS主从同步实验
  • LeetCode 解题思路 42(Hot 100)
  • DDPM(diffusion)原理
  • 健康养生:拥抱美好生活的基石
  • LangChain框架-检索器详解
  • Map和Set相关练习
  • c++_csp-j算法 (2)
  • Vue中的template配置项
  • Kafka下载和使用(Windows版)
  • docker 大模型
  • 【数学】勾股定理