【慢摹】TRL训练器细节详解(SFT/PPO/DPO/GRPO)
序言
本文总结一下目前TRL典型的训练器的实现细节(SFT,PPO,DPO,GRPO),也是对上一文【速写】PPOTrainer样例与错误思考(少量DAPO)的补充
目前DeepSeek关于各个训练器细节的掌握,尤其是PPOTrainer
的问题,依然回答得很差,这个在上文中已经详细指出,它写的代码大多数都是跑不通的,而官方给出的PPO示例(ppo.py)似乎也有一些瑕疵,本文将会一一指出与各位探讨。
目前笔者写了一个对于4个典型训练器(SFT,PPO,DPO,GRPO)都适用的一个base_pipeline
,以及对应的单元测试模块,其中PPOTrainer
是最为繁琐的一个训练器,它需要更多的模型(奖励模型,价值模型,参考模型),并且对应的训练数据集train_dataset
的处理方式与其他训练器存在显著区别。根据这份base_pipeline
的代码,我将逐一探讨TRL训练器中的细节问题,谨以抛砖引玉,欢迎探讨。
文章目录
- 序言
- 1 TRL通用的训练器Pipeline测试及其细节说明
- 1.1 训练配置参数的设置(Config)
- 1.2 加载分词器与模型(tokenizer & model)
- 1.3 数据集分割与字段问题(dataset)
- 1.4 训练与保存(checkpoint)
- 1.5 单元测试
- 2 一些其他的问题
- 2.1 DataProcessor与DataCollator
- 2.1.1 `dataset.map(prompt_formatter)`:
- 2.1.2 `data_collator`:
- 2.1.3 何时选择哪种方式?
- 2.2 关于PartialState
- 2.2.1 核心功能
- 2.2.2 典型使用场景
- 2.2.3 参数与底层机制
- 2.2.4 与类似方法的区别
- 2.2.5 完整示例(分布式训练中的数据加载)
- 2.2.6 注意事项
- 2.2.7 总结
- 2.3 DPO和PPO的reference_model的区别
- 2.3.1 DPO的核心思想与Reference Model的作用
- 2.3.2 **为什么DPO需要Reference Model,尽管它不是RL?**
- 2.3.3 TRL的DPOTrainer中的`ref_model`
- 2.3.4 **DPO vs PPO的Reference Model**
- 2.3.5 **如果没有Reference Model会怎样?**
1 TRL通用的训练器Pipeline测试及其细节说明
完整的项目在GitHub@caoyang-sufe/easyllm,以下我们先来看下面代码中的base_pipeline
,其他4个pipeline(sft_pipeline
,ppo_pipeline
,dpo_pipeline
,grpo_pipeline
)都是直接调用它使用。
对应的脚本在trainer.py
这个pipeline很重要,接下来所有的讨论都是围绕它逐行展开的。
# -*- coding: utf8 -*-
# @author: caoyang
# @email: caoyang@stu.sufe.edu.cnimport wandb
import logging
from copy import deepcopy
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, HfArgumentParser
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from trl import (ScriptArguments, ModelConfig, SFTConfig, SFTTrainer,PPOConfig, PPOTrainer,DPOConfig, DPOTrainer,GRPOConfig, GRPOTrainer,get_peft_config, get_quantization_config,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
from src.tools.trl import update_trl_config, generate_simple_data_processor# Trainer Pipeline
# @param name: [Str] e.g. "SFT", "PPO", "DPO", "GRPO"
# @param data_processor: Function object prepared for `dataset.map(data_processor)`
# @param trainer_config: [Dict, peft.XXXConfig] including keyword arguments, e.g.
# @param model_config: [Dict, peft.ModelConfig] including keyword arguments, e.g.
# @param script_arguments: [Dict, peft.ScriptArguments] including keyword arguments, e.g. "dataset_name", "dataset_train_split", "dataset_test_split"
# @param config_kwargs: [Dict] keyword arguments for updating TRL-Config, `ScriptArguments`, `ModelConfig`
# - keyword arguments for `TRLConfig`: e.g. "output_dir", "adam_xxx", "learning_rate", "kl_coef", "push_to_hub"
# - keyword arguments for `ScriptArguments`: e.g. "output_dir", "adam_xxx", "learning_rate", "kl_coef", "push_to_hub"
# - keyword arguments for `ModelConfig`: e.g. "model_name_or_path", "torch_dtype", "trust_remote_code", "use_peft", "lora_xxx", "load_in_4bit", "bnb_4bit_compute_dtype", "bnb_4bit_quant_type"
# @param trainer_kwargs: [Dict] keyword arguments for updating TRL-Trainer
# - keyword arguments for all Trainers: e.g. "data_collator", "callbacks"
# - keyword arguments for `SFTTrainer`: e.g. "compute_loss_func", "compute_metrics"
# - keyword arguments for `PPOTrainer`: e.g. "ref_model[required]", "reward_model[required]", "value_model[required]"
# - keyword arguments for `DPOTrainer`: e.g. "ref_model"
# - keyword arguments for `GRPOTrainer`: e.g. "reward_funcs[required]"
def base_pipeline(name, data_processor, config_kwargs, trainer_kwargs):# 1 ConfigurationTRLConfig, TRLTrainer = eval(f"{name}Config"), eval(f"{name}Trainer")parser = HfArgumentParser((ScriptArguments, TRLConfig, ModelConfig))script_arguments, trainer_config, model_config = parser.parse_args_into_dataclasses()script_arguments = update_trl_config(script_arguments, **config_kwargs)trainer_config = update_trl_config(trainer_config, **config_kwargs)model_config = update_trl_config(model_config, **config_kwargs)peft_config = get_peft_config(model_config)quantization_config = get_quantization_config(model_config)# 2 Load models and tokenizerlogging.info("Load models and tokenizer ...")logging.info(f" - Model: {model_config.model_name_or_path}")tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)if not "pad_token" in tokenizer.special_tokens_map:tokenizer.add_special_tokens({"pad_token": "[PAD]"})if tokenizer.chat_template is None:tokenizer.chat_template = SIMPLE_CHAT_TEMPLATEmodel = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path,device_map = "auto",trust_remote_code = model_config.trust_remote_code,quantization_config = quantization_config,)if peft_config is not None:logging.info("Prepare model for PEFT ...")model.config.pretraining_tp = 1model.config.use_cache = Falsemodel.gradient_checkpointing_enable()# If `prepare_model_for_kbit_training` is ignored, and `gradient_checkpointing = True` (for GPU memory saving)# Then you need set `model.enable_input_require_grads()` yourself# model = prepare_model_for_kbit_training(model)model.enable_input_require_grads()model = get_peft_model(model, peft_config)if name == "PPO":logging.info("PPO load reward value and reference models ...")# PPO is special! It needs more components!logging.info(f" - Reward model: {trainer_config.reward_model_path}")reward_model = AutoModelForSequenceClassification.from_pretrained(trainer_config.reward_model_path,trust_remote_code = model_config.trust_remote_code,num_labels = 1,)value_model = AutoModelForSequenceClassification.from_pretrained(trainer_config.reward_model_path,trust_remote_code = model_config.trust_remote_code,num_labels = 1,)logging.info(" - Copy reference model ...")ref_model = deepcopy(model)# ref_model = model.__class__(model.config)# ref_model.load_state_dict(model.state_dict())trainer_kwargs["reward_model"] = reward_modeltrainer_kwargs["value_model"] = value_modeltrainer_kwargs["ref_model"] = ref_modellogging.info(" - Done!")if data_processor is None:# The data processor of PPO is also different to othersdef data_processor(_data):outputs = tokenizer(_data["prompt"] + _data["completion"], padding = False)return {"input_ids": outputs["input_ids"]}# 2 Load datasetlogging.info("Load dataset ...")logging.info(f" - Dataset: {script_arguments.dataset_name}")if data_processor is None:data_processor = generate_simple_data_processor(name)train_dataset = load_dataset(script_arguments.dataset_name, split=script_arguments.dataset_train_split)eval_dataset = load_dataset(script_arguments.dataset_name, split=script_arguments.dataset_test_split)train_dataset = train_dataset.map(data_processor, remove_columns=train_dataset.column_names)eval_dataset = eval_dataset.map(data_processor, remove_columns=eval_dataset.column_names)logging.info(f" - Train dataset: {len(train_dataset)}")logging.info(f" - Eval dataset: {len(eval_dataset)}")# 4 Train modellogging.info("Trainer starts ...")trainer = TRLTrainer(model = model,args = trainer_config,train_dataset = train_dataset,eval_dataset = eval_dataset,processing_class = tokenizer,peft_config = peft_config,**trainer_kwargs)trainer.train()logging.info(" - Trainer finishes!")# 5 Save modelif trainer_config.push_to_hub:logging.info(f" - Push checkpoints to {trainer_config.organization}/{trainer_config.push_to_hub_model_id}")trainer.push_to_hub()logging.info(f"Save model to {trainer_config.output_dir}")trainer.save_model(trainer_config.output_dir)
# SFT Pipeline
def sft_pipeline(data_processor, config_kwargs, trainer_kwargs):base_pipeline(name = "SFT",data_processor = data_processor,config_kwargs = config_kwargs,trainer_kwargs = trainer_kwargs,)
# PPO Pipeline
def ppo_pipeline(data_processor, config_kwargs, trainer_kwargs):base_pipeline(name = "PPO",data_processor = data_processor,config_kwargs = config_kwargs,trainer_kwargs = trainer_kwargs,)
# DPO Pipeline
def dpo_pipeline(data_processor, config_kwargs, trainer_kwargs):base_pipeline(name = "DPO",data_processor = data_processor,config_kwargs = config_kwargs,trainer_kwargs = trainer_kwargs,)
# GRPO Pipeline
def grpo_pipeline(data_processor, config_kwargs, trainer_kwargs):base_pipeline(name = "GRPO",data_processor = data_processor,config_kwargs = config_kwargs,trainer_kwargs = trainer_kwargs,)
对应的单元测试脚本在trainer_pipelines.py,这里面涉及3个模型和4个数据集,都可以在huggingface上直接下载得到,model_home
和dataset_home
可根据本地路径进行修改。
# -*- coding: utf8 -*-
# @author: caoyang
# @email: caoyang@stu.sufe.edu.cnimport os
import logging
from src.pipelines.trainer import base_pipeline, sft_pipeline, ppo_pipeline, dpo_pipeline, grpo_pipelinemodel_home = "/nfsshare/home/caoyang/resource/model"
dataset_home = "/nfsshare/home/caoyang/resource/dataset"
model_names = ["Qwen/Qwen2.5-0.5B-Instruct","EleutherAI/pythia-1b-deduped","EleutherAI/pythia-160m",
]dataset_names = ["trl-lib/tldr", # train["prompt", "completion"] + validation["prompt", "completion"] + test["prompt", "completion"]"trl-lib/ultrafeedback_binarized", # train["chosen", "rejected", "score_chosen", "score_rejected"] + test["chosen", "rejected", "score_chosen", "score_rejected"]"trl-internal-testing/descriptiveness-sentiment-trl-style", # sentiment["prompt", "chosen", "rejected"] + descriptiveness["prompt", "chosen", "rejected"]"YeungNLP/firefly-train-1.1M", # train["input", "target"]
]def sft_pipeline_test():logging.info("SFT unittest ...")model_name_or_path = os.path.join(model_home, model_names[0])dataset_name = os.path.join(dataset_home, dataset_names[0])data_processor = Noneconfig_kwargs = {"output_dir": f"./temp/sft+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}","model_name_or_path": model_name_or_path,"dataset_name": dataset_name,"trust_remote_code": True,"dataset_train_split": "train[:500]","dataset_test_split": "validation[500:600]","use_peft": True,"report_to": "none","lora_target_modules": ["q_proj", "k_proj", "v_proj"]}trainer_kwargs = {}sft_pipeline(data_processor, config_kwargs, trainer_kwargs)def ppo_pipeline_test():logging.info("PPO unittest ...")model_name_or_path = os.path.join(model_home, model_names[1])"""EleutherAI/pythia-1b-dedupedGPTNeoXForCausalLM((gpt_neox): GPTNeoXModel((embed_in): Embedding(50304, 2048)(emb_dropout): Dropout(p=0.0, inplace=False)(layers): ModuleList((0-15): 16 x GPTNeoXLayer((input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)(post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)(post_attention_dropout): Dropout(p=0.0, inplace=False)(post_mlp_dropout): Dropout(p=0.0, inplace=False)(attention): GPTNeoXAttention((query_key_value): Linear(in_features=2048, out_features=6144, bias=True)(dense): Linear(in_features=2048, out_features=2048, bias=True))(mlp): GPTNeoXMLP((dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)(dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)(act): GELUActivation())))(final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)(rotary_emb): GPTNeoXRotaryEmbedding())(embed_out): Linear(in_features=2048, out_features=50304, bias=False))"""dataset_name = os.path.join(dataset_home, dataset_names[0])reward_model_path = os.path.join(model_home, model_names[2])data_processor = Noneconfig_kwargs = {"output_dir": f"./temp/ppo+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}","model_name_or_path": model_name_or_path,"dataset_name": dataset_name,"reward_model_path": reward_model_path,"trust_remote_code": True,"dataset_train_split": "train[:500]","dataset_test_split": "validation[:100]","use_peft": True,"report_to": "none","lora_target_modules": ["query_key_value"],}trainer_kwargs = {}ppo_pipeline(data_processor, config_kwargs, trainer_kwargs)def dpo_pipeline_test():logging.info("DPO unittest ...")model_name_or_path = os.path.join(model_home, model_names[0])dataset_name = os.path.join(dataset_home, dataset_names[2])data_processor = Noneconfig_kwargs = {"output_dir": f"./temp/dpo+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}","model_name_or_path": model_name_or_path,"dataset_name": dataset_name,"trust_remote_code": True,"dataset_train_split": "descriptiveness[:500]","dataset_test_split": "descriptiveness[500:600]","use_peft": True,"report_to": "none","lora_target_modules": ["q_proj", "k_proj", "v_proj"]}trainer_kwargs = {}dpo_pipeline(data_processor, config_kwargs, trainer_kwargs)def grpo_pipeline_test():logging.info("GRPO unittest ...")model_name_or_path = os.path.join(model_home, model_names[0])dataset_name = os.path.join(dataset_home, dataset_names[0])data_processor = Nonedef reward_funcs(completions, **kwargs):return [float(len(set(completion))) for completion in completions]config_kwargs = {"output_dir": f"./temp/grpo+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}","model_name_or_path": model_name_or_path,"dataset_name": dataset_name,"trust_remote_code": True,"dataset_train_split": "train[:500]","dataset_test_split": "validation[:100]","use_peft": True,"report_to": "none","lora_target_modules": ["q_proj", "k_proj", "v_proj"]}trainer_kwargs = {"reward_funcs": reward_funcs,}grpo_pipeline(data_processor, config_kwargs, trainer_kwargs)
目前单元测试都能通过,关键库版本如下:
accelerate==1.6.0
datasets==3.5.0
peft==0.15.2
torch==2.5.1
transformers==4.51.3
trl==0.17.0
1.1 训练配置参数的设置(Config)
目前模型训练涉及的参数非常的广泛,以前搭积木时代训练模型时,我们一般会自定义一个Config
类用于管理与模型、训练、数据集等相关的参数,但目前Transformers提供了非常好的工具HfArgumentParser
用于管理这些繁杂的参数,根据官方给出的PPO示例(ppo.py)中的写法👇
parser = HfArgumentParser((ScriptArguments, PPOConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_into_dataclasses()
将所有参数分为三类:ScriptArguments, PPOConfig, ModelConfig
使用HfArgumentParser
解析参数的逻辑是(这个过程没有读源码,可能实际上略有出入,但是结果应该是一致的),首先使用默认值初始化ScriptArguments, PPOConfig, ModelConfig
对应的三个对象,然后将执行脚本中传入的参数按名称对应分配给每个对象。
例如,在PPO的例子的执行脚本中:
python -i examples/scripts/ppo/ppo.py \--dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \--dataset_train_split descriptiveness \--learning_rate 3e-6 \--output_dir models/minimal/ppo \--per_device_train_batch_size 64 \--gradient_accumulation_steps 1 \--total_episodes 10000 \--model_name_or_path EleutherAI/pythia-1b-deduped \--missing_eos_penalty 1.0
会将dataset_name
和dataset_train_split
的值更新给script_args
,learning_rate
和output_dir
的值更新给model_args
,
不妨可以看看script_args, training_args, model_args
分别包含了哪些参数:
script_args
ScriptArguments(dataset_name=None, dataset_config=None, dataset_train_split='train', dataset_test_split='test', gradient_checkpointing_use_reentrant=False, ignore_bias_buffers=False,
)
training_args
(即ppo_config
)
PPOConfig(output_dir='trainer_output',overwrite_output_dir=False,do_train=False,do_eval=False,do_predict=False,eval_strategy=<IntervalStrategy.NO: 'no'>,prediction_loss_only=False,per_device_train_batch_size=8,per_device_eval_batch_size=8,per_gpu_train_batch_size=None,per_gpu_eval_batch_size=None,gradient_accumulation_steps=1,eval_accumulation_steps=None,eval_delay=0,torch_empty_cache_steps=None,learning_rate=5e-05,weight_decay=0.0,adam_beta1=0.9,adam_beta2=0.999,adam_epsilon=1e-08,max_grad_norm=1.0,num_train_epochs=3.0,max_steps=-1,lr_scheduler_type=<SchedulerType.LINEAR: 'linear'>,lr_scheduler_kwargs={},warmup_ratio=0.0,warmup_steps=0,log_level='passive',log_level_replica='warning',log_on_each_node=True,logging_dir='trainer_output\\runs\\Jun14_11-14-00_LAPTOP-PJP6MGE1',logging_strategy=<IntervalStrategy.STEPS: 'steps'>,logging_first_step=False,logging_steps=500,logging_nan_inf_filter=True,save_strategy=<SaveStrategy.STEPS: 'steps'>,save_steps=500,save_total_limit=None,save_safetensors=True,save_on_each_node=False,save_only_model=False,restore_callback_states_from_checkpoint=False,no_cuda=False,use_cpu=False,use_mps_device=False,seed=42,data_seed=None,jit_mode_eval=False,use_ipex=False,bf16=False,fp16=False,fp16_opt_level='O1',half_precision_backend='auto',bf16_full_eval=False,fp16_full_eval=False,tf32=None,local_rank=0,ddp_backend=None,tpu_num_cores=None,tpu_metrics_debug=False,debug=[],dataloader_drop_last=False,eval_steps=None,dataloader_num_workers=0,dataloader_prefetch_factor=None,past_index=-1,run_name='trainer_output',disable_tqdm=False,remove_unused_columns=True,label_names=None,load_best_model_at_end=False,metric_for_best_model=None,greater_is_better=None,ignore_data_skip=False,fsdp=[],fsdp_min_num_params=0,fsdp_config={'min_num_params': 0,'xla': False,'xla_fsdp_v2': False,'xla_fsdp_grad_ckpt': False},fsdp_transformer_layer_cls_to_wrap=None,accelerator_config=AcceleratorConfig(split_batches=False,dispatch_batches=None,even_batches=True,use_seedable_sampler=True,non_blocking=False,gradient_accumulation_kwargs=None,use_configured_state=False),deepspeed=None,label_smoothing_factor=0.0,optim=<OptimizerNames.ADAMW_TORCH: 'adamw_torch'>,optim_args=None,adafactor=False,group_by_length=False,length_column_name='length',report_to=['wandb'],ddp_find_unused_parameters=None,ddp_bucket_cap_mb=None,ddp_broadcast_buffers=None,dataloader_pin_memory=True,dataloader_persistent_workers=False,skip_memory_metrics=True,use_legacy_prediction_loop=False,push_to_hub=False,resume_from_checkpoint=None,hub_model_id=None,hub_strategy=<HubStrategy.EVERY_SAVE: 'every_save'>,hub_token=None,hub_private_repo=None,hub_always_push=False,gradient_checkpointing=False,gradient_checkpointing_kwargs=None,include_inputs_for_metrics=False,include_for_metrics=[],eval_do_concat_batches=True,fp16_backend='auto',push_to_hub_model_id=None,push_to_hub_organization=None,push_to_hub_token=None,mp_parameters='',auto_find_batch_size=False,full_determinism=False,torchdynamo=None,ray_scope='last',ddp_timeout=1800,torch_compile=False,torch_compile_backend=None,torch_compile_mode=None,include_tokens_per_second=False,include_num_input_tokens_seen=False,neftune_noise_alpha=None,optim_target_modules=None,batch_eval_metrics=False,eval_on_start=False,use_liger_kernel=False,eval_use_gather_object=False,average_tokens_across_devices=False,dataset_num_proc=None,num_mini_batches=1,total_episodes=None,local_rollout_forward_batch_size=64,num_sample_generations=10,response_length=53,stop_token=None,stop_token_id=None,temperature=0.7,missing_eos_penalty=None,sft_model_path='EleutherAI/pythia-160m',world_size=None,num_total_batches=None,micro_batch_size=None,local_batch_size=None,batch_size=None,local_mini_batch_size=None,mini_batch_size=None,exp_name='ppo_config',reward_model_path='EleutherAI/pythia-160m',model_adapter_name=None,ref_adapter_name=None,num_ppo_epochs=4,whiten_rewards=False,kl_coef=0.05,kl_estimator='k1',cliprange=0.2,vf_coef=0.1,cliprange_value=0.2,gamma=1.0,lam=0.95,ds3_gather_for_generation=True
)
model_args
ModelConfig(model_name_or_path=None, model_revision='main', torch_dtype=None, trust_remote_code=False, attn_implementation=None, use_peft=False, lora_r=16, lora_alpha=32, lora_dropout=0.05, lora_target_modules=None, lora_modules_to_save=None, lora_task_type='CAUSAL_LM', use_rslora=False, use_dora=False, load_in_8bit=False, load_in_4bit=False, bnb_4bit_quant_type='nf4', use_bnb_nested_quant=False,
)
其中:
-
script_args
:主要控制数据集的参数。特别地,dataset_name
是required,但这个事情比较奇怪- 如果是在Linux上使用脚本启动的话是必须传入
--dataset_name
或者--dataset-name
参数的,否则会显示error: the following arguments are required: --dataset_name/--dataset-name
(transformers==4.51.3, trl==0.17.0
) - 在Windows系统上即使不传入任何参数也是不会报错的(
transformers==4.52.4, trl==0.18.1
,比Linux上的版本要高一点,但应该不是版本问题)。
- 如果是在Linux上使用脚本启动的话是必须传入
-
training_args
:即PPOConfig
的对象,对应PPOTrainer
中的args
参数,其中包含了非常多与训练参数,具体每个训练器及其配置的参数列表可以直接到HuggingFace上的TRL文档查看(https://huggingface.co/docs/trl),这里例举常用的参数:output_dir
:模型checkpoint导出的路径,adam_xxx
:Adam优化器相关的参数,似乎TRL里所有Trainer的默认优化器都是Adam,不过是可以在PPOTrainer
的参数中定义优化器的(PPOTrainer
的优化器参数optimizers
是需要传入两个变量(optimizers (tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]
)),一个是优化器,另一个是学习率的规划器。learning_rate
:学习率,kl_coef
:KL散度的惩罚系数,越大相当于策略更新幅度越小,越小策略就可能更新幅度越快,
-
model_args
对应的是PPOTrainer
的peft_config
参数,里面也有一些和模型初始化相关的参数,包括量化与PEFT相关的参数也是从中提取得到的,比如在base_pipeline
脚本中:
...
# Trainer Pipeline
# @param name: [Str] e.g. "SFT", "PPO", "DPO", "GRPO"
# @param data_processor: Function object prepared for `dataset.map(data_processor)`
# @param trainer_config: [Dict, peft.XXXConfig] including keyword arguments, e.g.
# @param model_config: [Dict, peft.ModelConfig] including keyword arguments, e.g.
# @param script_arguments: [Dict, peft.ScriptArguments] including keyword arguments, e.g. "dataset_name", "dataset_train_split", "dataset_test_split"
# @param config_kwargs: [Dict] keyword arguments for updating TRL-Config, `ScriptArguments`, `ModelConfig`
# - keyword arguments for `TRLConfig`: e.g. "output_dir", "adam_xxx", "learning_rate", "kl_coef", "push_to_hub"
# - keyword arguments for `ScriptArguments`: e.g. "output_dir", "adam_xxx", "learning_rate", "kl_coef", "push_to_hub"
# - keyword arguments for `ModelConfig`: e.g. "model_name_or_path", "torch_dtype", "trust_remote_code", "use_peft", "lora_xxx", "load_in_4bit", "bnb_4bit_compute_dtype", "bnb_4bit_quant_type"
# @param trainer_kwargs: [Dict] keyword arguments for updating TRL-Trainer
# - keyword arguments for all Trainers: e.g. "data_collator", "callbacks"
# - keyword arguments for `SFTTrainer`: e.g. "compute_loss_func", "compute_metrics"
# - keyword arguments for `PPOTrainer`: e.g. "ref_model[required]", "reward_model[required]", "value_model[required]"
# - keyword arguments for `DPOTrainer`: e.g. "ref_model"
# - keyword arguments for `GRPOTrainer`: e.g. "reward_funcs[required]"
def base_pipeline(name, data_processor, config_kwargs, trainer_kwargs):# 1 ConfigurationTRLConfig, TRLTrainer = eval(f"{name}Config"), eval(f"{name}Trainer")parser = HfArgumentParser((ScriptArguments, TRLConfig, ModelConfig))script_arguments, trainer_config, model_config = parser.parse_args_into_dataclasses()script_arguments = update_trl_config(script_arguments, **config_kwargs)trainer_config = update_trl_config(trainer_config, **config_kwargs)model_config = update_trl_config(model_config, **config_kwargs)peft_config = get_peft_config(model_config)quantization_config = get_quantization_config(model_config)...
peft_config
和quantization_config
都是可以现成的函数生成的,它们被定义在trl
项目根目录下的trainer/utils.py
中:
def get_quantization_config(model_args: ModelConfig) -> Optional[BitsAndBytesConfig]:if model_args.load_in_4bit:quantization_config = BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_compute_dtype=model_args.torch_dtype, # For consistency with model weights, we use the same value as `torch_dtype`bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,bnb_4bit_quant_storage=model_args.torch_dtype,)elif model_args.load_in_8bit:quantization_config = BitsAndBytesConfig(load_in_8bit=True,)else:quantization_config = Nonereturn quantization_configdef get_peft_config(model_args: ModelConfig) -> "Optional[PeftConfig]":if model_args.use_peft is False:return Noneif not is_peft_available():raise ValueError("You need to have PEFT library installed in your environment, make sure to install `peft`. ""Make sure to run `pip install -U peft`.")peft_config = LoraConfig(task_type=model_args.lora_task_type,r=model_args.lora_r,target_modules=model_args.lora_target_modules,lora_alpha=model_args.lora_alpha,lora_dropout=model_args.lora_dropout,bias="none",use_rslora=model_args.use_rslora,use_dora=model_args.use_dora,modules_to_save=model_args.lora_modules_to_save,)return peft_config
从上面的源码可以看出,如果model_config.use_peft=False
,则默认不启用PEFT,得到的peft_config
也就是None
,同理是否采用量化取决于load_in_4bit
和load_in_8bit
是否至少有一个是True
1.2 加载分词器与模型(tokenizer & model)
回到base_pipeline
中对应的部分:
...# 2 Load models and tokenizerlogging.info("Load models and tokenizer ...")logging.info(f" - Model: {model_config.model_name_or_path}")tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)if not "pad_token" in tokenizer.special_tokens_map:tokenizer.add_special_tokens({"pad_token": "[PAD]"})if tokenizer.chat_template is None:tokenizer.chat_template = SIMPLE_CHAT_TEMPLATEmodel = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path,device_map = "auto",trust_remote_code = model_config.trust_remote_code,quantization_config = quantization_config,)if peft_config is not None:logging.info("Prepare model for PEFT ...")model.config.pretraining_tp = 1model.config.use_cache = Falsemodel.gradient_checkpointing_enable()# If `prepare_model_for_kbit_training` is ignored, and `gradient_checkpointing = True` (for GPU memory saving)# Then you need set `model.enable_input_require_grads()` yourself# model = prepare_model_for_kbit_training(model)model.enable_input_require_grads()model = get_peft_model(model, peft_config)...
这里主要是几个零碎的注意点:
- 分词器的词汇表中一定要包含
pad_token
,具体可以通过tokenizer.special_tokens_map
,否则在后面trainer.train()
时会报错,提示你设置pad_token
- 关于
tokenizer.chat_template
的问题,这个到接下来的数据处理部分会详细说明,针对对话类的任务,一般是会采用chat_template
来处理数据集的,因此可以用trl.trainer.utils
中提供的SIMPLE_CHAT_TEMPLATE
作为缺省 - 模型可以通过量化加载(即
quantization_config
不为None
时),这样会节约内存,此时一般都会做如下的设置👇(如果不是量化加载的话,可能加上这些配置也没什么关系,所以我就这么写了)- 这里的
pretraining_tp
是张量并行的意思,数值越大,就会并行的数量越多,当然只有多卡才能并行,单卡只能设为1 use_cache
应该是指(大概)做model.generate()
生成时是否进行kv_cache
缓存model.enable_input_require_grads()
也是节约内存的方法,我记得是缓存梯度以达到更快的反向传播运算的作用。
- 这里的
...model.config.pretraining_tp = 1model.config.use_cache = Falsemodel.gradient_checkpointing_enable()# If `prepare_model_for_kbit_training` is ignored, and `gradient_checkpointing = True` (for GPU memory saving)# Then you need set `model.enable_input_require_grads()` yourself# model = prepare_model_for_kbit_training(model)model.enable_input_require_grads()...
然后对于PPOTrainer
而言,非常特殊,它的构造参数还需要reward_model, value_model, ref_model
三个参数,因此在base_pipeline
中做了额外处理:
if name == "PPO":logging.info("PPO load reward value and reference models ...")# PPO is special! It needs more components!logging.info(f" - Reward model: {trainer_config.reward_model_path}")reward_model = AutoModelForSequenceClassification.from_pretrained(trainer_config.reward_model_path,trust_remote_code = model_config.trust_remote_code,num_labels = 1,)value_model = AutoModelForSequenceClassification.from_pretrained(trainer_config.reward_model_path,trust_remote_code = model_config.trust_remote_code,num_labels = 1,)logging.info(" - Copy reference model ...")ref_model = deepcopy(model)# ref_model = model.__class__(model.config)# ref_model.load_state_dict(model.state_dict())trainer_kwargs["reward_model"] = reward_modeltrainer_kwargs["value_model"] = value_modeltrainer_kwargs["ref_model"] = ref_modellogging.info(" - Done!")
这里注意一下:
- 一般初始化
ref_model
和目标模型是完全一样的,这里测试了两种方法,一种是直接ref_model = deepcopy(model)
,这确实是可行的,不会发生unpickled之类的问题,另一种方法则是重构一个和目标模型结构相同的空模型,然后加载状态字典:
这个方法在# ref_model = model.__class__(model.config)# ref_model.load_state_dict(model.state_dict())
use_peft=False
时是可行的,但是因为这里可能是一个peft_model
,因此model.__class__
的参数还需要一个peft_config
,直接这样写是会报错的 - 使用
HfArgumentParser
加载PPOConfig
参数时,ppo_config.reward_model_path
是有默认值的(可以看上面的training_args
),即'EleutherAI/pythia-160m'
,这个奖励模型必须是可以被AutoModelForSequenceClassification
类型加载的 value_model
通常设置为和reward_model
相同,至少在官方的PPO示例中是这样的,其实也可以解释,我的理解是:
A ( s , a ) = Q ( s , a ) − V ( s ) A(s,a)=Q(s,a)-V(s) A(s,a)=Q(s,a)−V(s)
中,这里 V ( s ) V(s) V(s)即value_model
, Q ( s , a ) = r ( s , a ) + γ V ( s ′ ) Q(s,a)=r(s,a)+\gamma V(s') Q(s,a)=r(s,a)+γV(s′)中的 r ( s , a ) r(s,a) r(s,a)即奖励函数,在CAUSAL_LM
的语境下,其实就是句子 s s s加了一个单词 a a a后的奖励,本质上都是对一句话进行评分。虽然说得通,但是感觉也挺奇怪的。
1.3 数据集分割与字段问题(dataset)
目前主流的数据集加载都是直接用HuggingFace的datasets.load_dataset
方法加载,标准的数据格式就是每个样本的格式为{“column_1”: <data_1>, "column_2": <data_2>}
,整个数据集类似jsonl的格式。
针对目前已知的几个训练器,它们对数据集的格式要求大概是这样的:
-
SFTTrainer
:字段要包含prompt
和completion
,注意,训练时只会在completion
部分计算损失,prompt
部分是不会计算损失的- 但是其实我发现如果数据集中只有
text
或者input
+target
这种字段时也是可行的,需要仔细读源码才行。
- 但是其实我发现如果数据集中只有
-
DPOTrainer
:这个就很简单,一般就是prompt
+chosen
+rejected
,但是在官方示例使用的数据集trl-lib/ultrafeedback_binarized
中,字段却是chosen, rejected, score_chosen, score_rejected
,也就是说其实prompt
可能并非必须,本来也确实可以直接用空字符串替代,然后数据集中如果没有回答进行评分的话,可能还有一个默认的评分机制在训练器里面。 -
GRPOTrainer
:这个我也是根据官方示例来的,我看到它使用的数据集是trl-lib/tldr
,这是一个典型的只有prompt
和completion
两个字段的数据集 -
PPOTrainer
:这个最为特殊,之前DeepSeek一直写不对的原因也在于此-
在官方给出的PPO运行脚本(ppo.py)中
def prepare_dataset(dataset, tokenizer):"""pre-tokenize the dataset before training; only collate during training"""def tokenize(element):outputs = tokenizer(element[dataset_text_field],padding=False,)return {"input_ids": outputs["input_ids"]}return dataset.map(tokenize,batched=True,remove_columns=dataset.column_names,num_proc=training_args.dataset_num_proc,)
注意到这里是直接处理成分词后的
input_ids
格式 -
但是,我本来以为PPO和GRPO是类似的,因此理论上在数据集的格式要求上应该也差不了太多,因此我觉得可能
prompt
+completion
也是可行的,但是测试下来会报错,报错提示是一定要求是带input_ids
字段的。 -
因此我单独给PPO写了一个数据处理的
data_processor
,而其余的都是很简单的:# @param name: [Str] e.g. "SFT", "PPO", "DPO", "GRPO" def generate_simple_data_processor(name, **kwargs):if name in ["SFT", "GRPO"]:def _data_processor(_data):return {"prompt": _data["prompt"], "completion": _data["completion"]}elif name == "PPO":tokenizer = kwargs.get("tokenizer")def _data_processor(_data):outputs = tokenizer(_data["prompt"] + _data["completion"], padding = False)return {"input_ids": outputs["input_ids"]}elif name == "DPO":def _data_processor(_data):return {"prompt": _data["prompt"], "chosen": _data["chosen"], "rejected": _data["rejected"]}else:raise NotImplementedError(name)return _data_processor
-
当然一些写法也会使用apply_chat_template
,这通常是用于交互式对话任务,都是可行的:
def _data_processor(_data):_message = [{"role": "system", "content": "You are an AI assistant developped by CY"},{"role": "user", "content": _data[dataset_input_column]},{"role": "assistant", "content": _data[dataset_target_column]},]_prompt = tokenizer.apply_chat_template(_message, tokenize=False)return {"text": _prompt}
然后一个小细节是在用dataset.map
进行数据处理时一般会设置参数emove_columns=dataset.column_names
,以丢弃原先不必要的字段,防止Trainer在运行时错用了其他的字段进行训练。
1.4 训练与保存(checkpoint)
base_pipeline
中最后一部分训练完然后保存模型,顺利地话就完整地跑通:
# 4 Train modellogging.info("Trainer starts ...")trainer = TRLTrainer(model = model,args = trainer_config,train_dataset = train_dataset,eval_dataset = eval_dataset,processing_class = tokenizer,peft_config = peft_config,**trainer_kwargs)trainer.train()logging.info(" - Trainer finishes!")# 5 Save modelif trainer_config.push_to_hub:logging.info(f" - Push checkpoints to {trainer_config.organization}/{trainer_config.push_to_hub_model_id}")trainer.push_to_hub()logging.info(f"Save model to {trainer_config.output_dir}")trainer.save_model(trainer_config.output_dir)
这里值得注意的几个问题:
-
在
trainer_config
(即SFTConfig, PPOConfig, DPOConfig, GRPOConfig
之类的对象)中,有一个参数trainer_config.report_to
,这个不设置的话默认是会上传到W&B的,网络不支持访问的话是会在训练到checkpoint的时候发生网络错误的,因此一般会设置成"none"
,或者也可以用"tensorboard"
- 不过就训练绘图的话,训练结束到
trainer_config.output_dir
中找到checkpoint-xxx
文件夹,里面会有trainer_state.json
文件,然后自己根据里面的数据绘图即可,例如:# Plot dynamics of TRL trainer state def plot_trl_dynamics(trainer_state_path):with open(trainer_state_path, 'r', encoding="utf8") as f:data = json.load(f)log_history = data["log_history"]steps = [entry["step"] for entry in log_history]episodes = [entry["episode"] for entry in log_history]epochs = [entry["epoch"] for entry in log_history]policy_loss = [entry["loss/policy_avg"] for entry in log_history]value_loss = [entry["loss/value_avg"] for entry in log_history]lrs = [entry["lr"] for entry in log_history]entropys = [entry["objective/entropy"] for entry in log_history]kls = [entry["objective/kl"] for entry in log_history]non_score_rewards = [entry["objective/non_score_reward"] for entry in log_history]rlhf_rewards = [entry["objective/rlhf_reward"] for entry in log_history]scores = [entry["objective/scores"] for entry in log_history]plt.figure(figsize=(8, 8))ax_1 = plt.subplot(2, 2, 1)ax_2 = plt.subplot(4, 2, 2)ax_3 = plt.subplot(4, 2, 4)ax_4 = plt.subplot(2, 2, 3)ax_5 = plt.subplot(2, 2, 4)ax_1.plot(steps, policy_loss, label="Policy Loss")ax_1.plot(steps, value_loss, label="Value Loss", linestyle="--")ax_1.set_xlabel("Step"), ax_1.set_ylabel("Loss"), ax_1.legend()ax_1.set_title("Policy and Value Loss")# ------------------------------------------------------------------ax_2.plot(steps, kls, label="objective/kl")ax_2.set_xlabel("Step"), ax_2.set_ylabel("KL"), ax_2.legend()ax_2.set_title("KL Curve")# ------------------------------------------------------------------ax_3.plot(steps, entropys, label="objective/entropy")ax_3.set_xlabel("Step"), ax_3.set_ylabel("Entropy"), ax_3.legend()ax_3.set_title("Entropy Curve")# ------------------------------------------------------------------ax_4.plot(steps, lrs, label="Learning Rate")ax_4.set_xlabel("Step"), ax_4.set_ylabel("Learning Rate"), ax_4.legend()ax_4.set_title("Learning Rate Curve")# ------------------------------------------------------------------ax_5.plot(steps, non_score_rewards, label="objective/non_score_reward", linestyle="--")ax_5.plot(steps, rlhf_rewards, label="objective/rlhf_reward", linestyle="--")ax_5.plot(steps, scores, label="objective/scores")ax_5.set_xlabel("Step"), ax_5.set_ylabel("Score/Reward"), ax_5.legend()ax_5.set_title("Reward and Score")plt.show()print(steps)print('-' * 64)print(episodes)print('-' * 64)print(epochs)
- 不过就训练绘图的话,训练结束到
-
关于W&B的使用:如果网络支持,使用
wandb
确实也是不错的选择:登录 wandb
方法 1:交互式登录(推荐)
运行以下代码,会弹出浏览器或要求输入 API Key:
import wandb wandb.login()
然后按提示操作即可。
方法 2:直接提供 API Key
import wandb wandb.login(key="your_api_key_here") # 替换成你的 wandb API Key
⚠️ 注意:不要直接在代码里硬编码 API Key(特别是公开代码时),建议使用环境变量或配置文件。
方法 3:通过环境变量设置(适合脚本/服务器)
export WANDB_API_KEY=your_api_key_here
然后在 Python 中直接
wandb.init()
即可。
步骤 3:在 Trainer 中启用 wandb
方式 1:自动记录(默认)
HuggingFace
Trainer
默认会使用wandb
(如果已安装并登录):from transformers import Trainer, TrainingArguments training_args = TrainingArguments(output_dir="./results",report_to="wandb", # 明确指定使用 wandb(默认已经是这样)logging_dir="./logs",# 其他训练参数... ) trainer = Trainer(model=model,args=training_args,train_dataset=train_dataset,eval_dataset=eval_dataset, ) trainer.train() # 自动记录到 wandb
方式 2:手动初始化 wandb(更灵活控制)
import wandb # 初始化 wandb 并设置项目名称 wandb.init(project="my-ppo-project", name="experiment-1") # 然后正常训练 trainer.train() # 训练结束后(可选) wandb.finish()
这样你可以自定义
wandb
的配置(如project
、name
、config
等)。
额外配置(可选)
1. 记录超参数
config = {"learning_rate": 2e-5,"batch_size": 32,"epochs": 3,
}
wandb.init(project="my-project", config=config)
训练时,这些参数会显示在 wandb 面板中。
2. 自定义记录指标
# 在训练循环中记录自定义指标
wandb.log({"reward": current_reward, "loss": current_loss})
3. 恢复中断的训练
wandb.init(id="之前运行的ID", resume="must")
- 训练开始后,终端会显示类似:
Tracking run with wandb, view at https://wandb.ai/your-username/your-project/runs/xxx
- 打开链接即可查看实时训练数据。
1.5 单元测试
单元测试脚本trainer_pipelines.py:
# -*- coding: utf8 -*-
# @author: caoyang
# @email: caoyang@stu.sufe.edu.cnimport os
import logging
from src.pipelines.trainer import base_pipeline, sft_pipeline, ppo_pipeline, dpo_pipeline, grpo_pipelinemodel_home = "/nfsshare/home/caoyang/resource/model"
dataset_home = "/nfsshare/home/caoyang/resource/dataset"
model_names = ["Qwen/Qwen2.5-0.5B-Instruct","EleutherAI/pythia-1b-deduped","EleutherAI/pythia-160m",
]dataset_names = ["trl-lib/tldr", # train["prompt", "completion"] + validation["prompt", "completion"] + test["prompt", "completion"]"trl-lib/ultrafeedback_binarized", # train["chosen", "rejected", "score_chosen", "score_rejected"] + test["chosen", "rejected", "score_chosen", "score_rejected"]"trl-internal-testing/descriptiveness-sentiment-trl-style", # sentiment["prompt", "chosen", "rejected"] + descriptiveness["prompt", "chosen", "rejected"]"YeungNLP/firefly-train-1.1M", # train["input", "target"]
]def sft_pipeline_test():logging.info("SFT unittest ...")model_name_or_path = os.path.join(model_home, model_names[0])dataset_name = os.path.join(dataset_home, dataset_names[0])data_processor = Noneconfig_kwargs = {"output_dir": f"./temp/sft+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}","model_name_or_path": model_name_or_path,"dataset_name": dataset_name,"trust_remote_code": True,"dataset_train_split": "train[:500]","dataset_test_split": "validation[500:600]","use_peft": True,"report_to": "none","lora_target_modules": ["q_proj", "k_proj", "v_proj"]}trainer_kwargs = {}sft_pipeline(data_processor, config_kwargs, trainer_kwargs)def ppo_pipeline_test():logging.info("PPO unittest ...")model_name_or_path = os.path.join(model_home, model_names[1])"""EleutherAI/pythia-1b-dedupedGPTNeoXForCausalLM((gpt_neox): GPTNeoXModel((embed_in): Embedding(50304, 2048)(emb_dropout): Dropout(p=0.0, inplace=False)(layers): ModuleList((0-15): 16 x GPTNeoXLayer((input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)(post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)(post_attention_dropout): Dropout(p=0.0, inplace=False)(post_mlp_dropout): Dropout(p=0.0, inplace=False)(attention): GPTNeoXAttention((query_key_value): Linear(in_features=2048, out_features=6144, bias=True)(dense): Linear(in_features=2048, out_features=2048, bias=True))(mlp): GPTNeoXMLP((dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)(dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)(act): GELUActivation())))(final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)(rotary_emb): GPTNeoXRotaryEmbedding())(embed_out): Linear(in_features=2048, out_features=50304, bias=False))"""dataset_name = os.path.join(dataset_home, dataset_names[0])reward_model_path = os.path.join(model_home, model_names[2])data_processor = Noneconfig_kwargs = {"output_dir": f"./temp/ppo+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}","model_name_or_path": model_name_or_path,"dataset_name": dataset_name,"reward_model_path": reward_model_path,"trust_remote_code": True,"dataset_train_split": "train[:500]","dataset_test_split": "validation[:100]","use_peft": True,"report_to": "none","lora_target_modules": ["query_key_value"],}trainer_kwargs = {}ppo_pipeline(data_processor, config_kwargs, trainer_kwargs)def dpo_pipeline_test():logging.info("DPO unittest ...")model_name_or_path = os.path.join(model_home, model_names[0])dataset_name = os.path.join(dataset_home, dataset_names[2])data_processor = Noneconfig_kwargs = {"output_dir": f"./temp/dpo+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}","model_name_or_path": model_name_or_path,"dataset_name": dataset_name,"trust_remote_code": True,"dataset_train_split": "descriptiveness[:500]","dataset_test_split": "descriptiveness[500:600]","use_peft": True,"report_to": "none","lora_target_modules": ["q_proj", "k_proj", "v_proj"]}trainer_kwargs = {}dpo_pipeline(data_processor, config_kwargs, trainer_kwargs)def grpo_pipeline_test():logging.info("GRPO unittest ...")model_name_or_path = os.path.join(model_home, model_names[0])dataset_name = os.path.join(dataset_home, dataset_names[0])data_processor = Nonedef reward_funcs(completions, **kwargs):return [float(len(set(completion))) for completion in completions]config_kwargs = {"output_dir": f"./temp/grpo+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}","model_name_or_path": model_name_or_path,"dataset_name": dataset_name,"trust_remote_code": True,"dataset_train_split": "train[:500]","dataset_test_split": "validation[:100]","use_peft": True,"report_to": "none","lora_target_modules": ["q_proj", "k_proj", "v_proj"]}trainer_kwargs = {"reward_funcs": reward_funcs,}grpo_pipeline(data_processor, config_kwargs, trainer_kwargs)
这里主要就是grpo_pipeline_test
中的reward_funcs
,其他几个测试函数都没什么好说的,很好理解
GRPOTrainer
中有required参数reward_funcs
,这个在TRL文档中给出的例子是:
def reward_funcs(completions, **kwargs):return [float(len(set(completion))) for completion in completions]
这个奖励的意思是会给长回答更高的得分。
注意这里的**kwargs
是不可以省略的,因为GRPOTrainer
调用reward_funcs
会传入其他参数,比如prompts
或者ground_truths
,在TRL的GRPO文档中有很多例子。
2 一些其他的问题
这部分摘取一些和DeepSeek讨论的问题
2.1 DataProcessor与DataCollator
关于TRL的Trainer
的一个问题,因为Trainer
是可以传入data_collator
参数的,但是一般来说我们都会在向Trainer
传入数据集(如train_dataset
)之前,先写好一个函数用于预处理数据集,比如train_dataset = train.dataset.map(prompt_formatter)
这样,我的理解其实这里prompt_formatter何data_collator的作用是一样的(我觉得它们处理数据的逻辑也是一样,就是传入一个样本,然后做数据处理),因此如果在Trainer外面已经做了train_dataset.map(prompt_formatter)
的操作,就不需要再传入data_collator了。反之也是一样,我可以传入data_collator
,而无需进行train_dataset.map(prompt_formatter)
的操作
以下回答由AI生成:
你的理解基本正确,但有一些细节需要注意。prompt_formatter
(在dataset.map
中使用)和data_collator
确实都涉及数据处理,但它们的作用时机和目的有所不同。下面我来详细解释它们的区别以及何时使用它们:
2.1.1 dataset.map(prompt_formatter)
:
- 作用时机: 在数据加载时,预处理阶段(即在训练开始之前)对数据集中的每个样本进行处理。
- 典型用途:
- 将原始文本转换为模型需要的输入格式(例如添加特殊token、构建prompt-completion结构)。
- 对数据进行过滤(通过
remove_columns
删除不需要的字段)。 - 执行一次性的、确定性的转换(例如tokenization,如果你不启用动态padding)。
- 特点:
- 处理结果是静态的(处理后的数据会被缓存,除非设置
load_from_cache_file=False
)。 - 适合所有样本共享的逻辑(例如统一的prompt模板)。
- 处理结果是静态的(处理后的数据会被缓存,除非设置
2.1.2 data_collator
:
- 作用时机: 在训练过程中,动态批处理阶段(即在每个batch生成时)对样本进行实时处理。
- 典型用途:
- 动态padding(将同一batch内的样本padding到相同长度)。
- 随机数据增强(例如随机masking,适用于MLM任务)。
- 构造特殊输入(例如为RLHF构建奖励模型的输入格式)。
- 特点:
- 处理是动态的(每次batch采样时重新执行)。
- 适合需要随机性或依赖batch内其他样本的操作(例如padding长度取决于当前batch)。
2.1.3 何时选择哪种方式?
-
只用
dataset.map
:- 如果你的数据处理是确定性的、不需要动态调整(例如固定的prompt模板),且不需要batch级别的操作(如padding),可以只使用
dataset.map
。 - 但注意:如果涉及变长序列,仍需
data_collator
处理padding(除非你已手动填充到固定长度)。
- 如果你的数据处理是确定性的、不需要动态调整(例如固定的prompt模板),且不需要batch级别的操作(如padding),可以只使用
-
只用
data_collator
:- 如果数据处理需要动态性(例如随机masking)或依赖batch上下文(如padding),可以只用
data_collator
。 - 但注意:这样可能导致训练时重复处理数据(例如每次epoch重复tokenization),影响效率。
- 如果数据处理需要动态性(例如随机masking)或依赖batch上下文(如padding),可以只用
-
联合使用:
- 常见做法是:
- 用
dataset.map
处理静态逻辑(例如tokenization、添加固定prompt)。 - 用
data_collator
处理动态逻辑(例如padding、随机增强)。
- 用
- 例如:
def prompt_formatter(example):example["input_ids"] = tokenizer(example["text"]).input_idsreturn example # 不padding,保留原始长度train_dataset = raw_dataset.map(prompt_formatter, batched=False)data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=True) # 动态padding和masking
- 常见做法是:
-
如果
prompt_formatter
已经完成了所有必要处理(包括padding到固定长度),且不需要任何动态逻辑,确实可以跳过data_collator
(或传入default_data_collator
,它仅做简单的batch堆叠)。 -
但如果需要动态操作(如padding),即使有
dataset.map
,仍需data_collator
。 -
二者并非互斥,而是互补的。
dataset.map
用于离线预处理,data_collator
用于在线批处理。 -
选择取决于你的数据处理需求是否依赖动态上下文或随机性。
2.2 关于PartialState
accelerate库里有一个方法:PartialState().local_main_process_first()
,我看到这个方法会用来加载数据集,这个看起来像是并行加载数据集的意思,你可以解释一下这个函数是用来做什么的吗?
注:这个方法本来是在PPO官方示例中使用的
PartialState().local_main_process_first()
是 Hugging Face accelerate
库中用于多进程并行处理的一个关键方法,尤其在分布式训练或推理时控制数据加载、文件处理等操作的执行顺序。它的核心作用是确保某些操作(如数据集加载或文件预处理)只在**主进程(main process)**上优先执行,避免多进程竞争或重复操作。
2.2.1 核心功能
local_main_process_first()
返回一个上下文管理器(context manager),其行为如下:- 主进程(local main process):会立即执行上下文内的代码块。
- 其他进程(非主进程):会等待主进程完成后,再跳过或执行相同操作(取决于具体逻辑)。
- 目的:避免多进程同时操作共享资源(如文件读取、数据集缓存)导致的冲突或冗余计算。
2.2.2 典型使用场景
(1) 加载数据集时避免重复缓存
在分布式训练中,如果多个进程同时尝试缓存数据集(例如通过 datasets
库),可能会导致文件写入冲突。通过 local_main_process_first()
可以确保只有主进程生成缓存,其他进程直接复用:
from accelerate import PartialStatestate = PartialState() # 初始化分布式状态# 只有主进程先缓存数据集,其他进程等待后直接加载缓存
with state.local_main_process_first():dataset = load_dataset("imdb") # 假设这是耗时操作
(2) 文件预处理或下载
如果预处理(如分词、数据转换)只需要执行一次:
with state.local_main_process_first():tokenized_data = dataset.map(tokenize_function, batched=True) # 主进程先处理
# 其他进程会自动等待主进程完成,然后直接使用处理后的数据
2.2.3 参数与底层机制
PartialState
:是accelerate
中管理分布式状态的类,自动检测当前进程的分布式环境(如多GPU、TPU等)。local_main_process
:指当前节点(机器)上的主进程(通常是全局主进程的一个子集,在多机训练时每台机器有自己的local_main_process
)。- 同步机制:底层通过
torch.distributed
或类似库实现进程间同步,确保非主进程等待主进程完成。
2.2.4 与类似方法的区别
方法/属性 | 作用范围 | 用途 |
---|---|---|
local_main_process_first() | 当前节点的主进程优先 | 避免文件/缓存操作的竞争 |
main_process_first() | 全局主进程优先(跨所有节点) | 适用于全局只需执行一次的操作 |
on_main_process() | 仅在主进程执行 | 完全跳过其他进程(如日志记录) |
2.2.5 完整示例(分布式训练中的数据加载)
from accelerate import PartialState
from datasets import load_dataset# 初始化分布式状态(自动检测环境)
state = PartialState()# 只有主进程优先加载并缓存数据,其他进程等待
with state.local_main_process_first():dataset = load_dataset("glue", "mrpc", split="train")tokenized_data = dataset.map(lambda x: tokenizer(x["sentence"]), batched=True)# 所有进程继续同步训练
for batch in tokenized_data:outputs = model(**batch)...
2.2.6 注意事项
- 缓存路径一致性:确保所有进程能访问相同的缓存路径(如共享存储)。
- 性能影响:主进程的预处理时间会被计入总耗时,但避免了重复操作的开销。
- 错误处理:如果主进程预处理失败,其他进程会一直等待,需添加超时或错误处理逻辑。
2.2.7 总结
- 何时使用:在分布式环境中,任何需要单例执行的操作(如数据缓存、文件下载、预处理)都应使用
local_main_process_first()
。 - 优势:减少资源竞争,避免冗余计算,确保分布式任务的高效协同。
如果需要进一步了解 accelerate
的分布式策略(如多机训练),可以继续探讨!
2.3 DPO和PPO的reference_model的区别
这个是阅读TRL文档发现的一个问题,DPOTrainer
的ref_model
参数是有默认值None
的,但是PPOTrainer
的ref_model
参数没有默认值。这意味着必须给PPOTrainer
传入ref_model
参数,尽管一般都是跟目标模型一样。
你提出了一个很好的问题!DPO(Direct Preference Optimization)虽然确实不属于传统强化学习的范畴,但它仍然保留了reference model,这是出于特定的设计考虑。下面我详细解释原因:
2.3.1 DPO的核心思想与Reference Model的作用
DPO通过直接优化偏好数据来规避强化学习,但其目标函数(源自Bradley-Terry模型)隐式地包含了一个KL散度约束:
L DPO = − E [ log σ ( β log π θ ( y w ∣ x ) π ref ( y w ∣ x ) − β log π θ ( y l ∣ x ) π ref ( y l ∣ x ) ) ] \mathcal{L}_{\text{DPO}} = -\mathbb{E} \left[ \log \sigma\left( \beta \log \frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)} - \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{\text{ref}}(y_l|x)} \right) \right] LDPO=−E[logσ(βlogπref(yw∣x)πθ(yw∣x)−βlogπref(yl∣x)πθ(yl∣x))]
其中:
- π θ \pi_\theta πθ 是待训练的策略(当前模型)。
- π ref \pi_{\text{ref}} πref 是reference model(通常初始化为预训练模型,例如SFT模型)。
- β \beta β 是控制偏离参考模型程度的超参数。
Reference Model的作用:
- 防止过度偏离初始策略:KL散度约束(通过reference model体现)避免模型为迎合偏好数据而退化(例如生成极端或无意义的输出)。
- 隐式正则化:类似于PPO中的KL惩罚,但DPO通过直接在损失函数中对比当前模型和reference模型的输出概率实现。
2.3.2 为什么DPO需要Reference Model,尽管它不是RL?
- 数学推导的必然性:DPO的公式是从逆向强化学习的目标推导而来的,其中reference model代表初始策略(即RL中的 π old \pi_{\text{old}} πold)。去掉reference model会导致目标函数失去约束,模型可能过拟合偏好数据。
- 实践中的稳定性:实验表明,去掉reference model后,模型容易过度优化偏好信号(例如始终输出短答案或重复模式)。
2.3.3 TRL的DPOTrainer中的ref_model
在HuggingFace TRL库中,DPOTrainer
确实需要传入ref_model
参数,原因如下:
- 计算概率比:DPO需要比较当前模型和reference模型对正/负样本的概率(即(\log \frac{\pi_\theta(y|x)}{\pi_{\text{ref}}(y|x)}))。
- 冻结参考模型:
ref_model
在训练过程中是冻结的(不更新),仅作为基准。 - 内存优化:实际实现中,
ref_model
可能与当前模型共享部分结构(例如共享底层Transformer),但通过梯度截断避免更新。
2.3.4 DPO vs PPO的Reference Model
特性 | PPO (RL) | DPO (Non-RL) |
---|---|---|
Reference Model | 旧策略 π old \pi_{\text{old}} πold | 预训练模型 π ref \pi_{\text{ref}} πref |
更新方式 | 定期同步(on-policy要求) | 完全冻结 |
作用 | 重要性采样校正 + KL约束 | 隐式KL约束 + 防止过拟合 |
2.3.5 如果没有Reference Model会怎样?
假设令 π ref \pi_{\text{ref}} πref为均匀分布,则DPO退化为直接最大化偏好对的似然,可能导致:
- 模型过度自信(overconfident)。
- 生成内容多样性下降(模式坍塌)。
- 对噪声偏好数据过拟合。
DPO虽然跳过了显式的强化学习框架,但通过保留reference model,继承了RL中“约束策略更新”的核心思想。这是其稳定性和性能的关键设计,也是它区别于纯监督学习(如直接微调)的重要特征。
PS:我感觉这个回答没有解释我的问题,答非所问了