当前位置: 首页 > java >正文

【慢摹】TRL训练器细节详解(SFT/PPO/DPO/GRPO)

序言

本文总结一下目前TRL典型的训练器的实现细节(SFT,PPO,DPO,GRPO),也是对上一文【速写】PPOTrainer样例与错误思考(少量DAPO)的补充

目前DeepSeek关于各个训练器细节的掌握,尤其是PPOTrainer的问题,依然回答得很差,这个在上文中已经详细指出,它写的代码大多数都是跑不通的,而官方给出的PPO示例(ppo.py)似乎也有一些瑕疵,本文将会一一指出与各位探讨。

目前笔者写了一个对于4个典型训练器(SFT,PPO,DPO,GRPO)都适用的一个base_pipeline,以及对应的单元测试模块,其中PPOTrainer是最为繁琐的一个训练器,它需要更多的模型(奖励模型,价值模型,参考模型),并且对应的训练数据集train_dataset的处理方式与其他训练器存在显著区别。根据这份base_pipeline的代码,我将逐一探讨TRL训练器中的细节问题,谨以抛砖引玉,欢迎探讨。


文章目录

  • 序言
  • 1 TRL通用的训练器Pipeline测试及其细节说明
    • 1.1 训练配置参数的设置(Config)
    • 1.2 加载分词器与模型(tokenizer & model)
    • 1.3 数据集分割与字段问题(dataset)
    • 1.4 训练与保存(checkpoint)
    • 1.5 单元测试
  • 2 一些其他的问题
    • 2.1 DataProcessor与DataCollator
      • 2.1.1 `dataset.map(prompt_formatter)`:
      • 2.1.2 `data_collator`:
      • 2.1.3 何时选择哪种方式?
    • 2.2 关于PartialState
      • 2.2.1 核心功能
      • 2.2.2 典型使用场景
      • 2.2.3 参数与底层机制
      • 2.2.4 与类似方法的区别
      • 2.2.5 完整示例(分布式训练中的数据加载)
      • 2.2.6 注意事项
      • 2.2.7 总结
    • 2.3 DPO和PPO的reference_model的区别
      • 2.3.1 DPO的核心思想与Reference Model的作用
      • 2.3.2 **为什么DPO需要Reference Model,尽管它不是RL?**
      • 2.3.3 TRL的DPOTrainer中的`ref_model`
      • 2.3.4 **DPO vs PPO的Reference Model**
      • 2.3.5 **如果没有Reference Model会怎样?**


1 TRL通用的训练器Pipeline测试及其细节说明

完整的项目在GitHub@caoyang-sufe/easyllm,以下我们先来看下面代码中的base_pipeline,其他4个pipeline(sft_pipelineppo_pipelinedpo_pipelinegrpo_pipeline)都是直接调用它使用。

对应的脚本在trainer.py

这个pipeline很重要,接下来所有的讨论都是围绕它逐行展开的。

# -*- coding: utf8 -*-
# @author: caoyang
# @email: caoyang@stu.sufe.edu.cnimport wandb
import logging
from copy import deepcopy
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, HfArgumentParser
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from trl import (ScriptArguments, ModelConfig, SFTConfig, SFTTrainer,PPOConfig, PPOTrainer,DPOConfig, DPOTrainer,GRPOConfig, GRPOTrainer,get_peft_config, get_quantization_config,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
from src.tools.trl import update_trl_config, generate_simple_data_processor# Trainer Pipeline
# @param name: [Str] e.g. "SFT", "PPO", "DPO", "GRPO"
# @param data_processor: Function object prepared for `dataset.map(data_processor)`
# @param trainer_config: [Dict, peft.XXXConfig] including keyword arguments, e.g. 
# @param model_config: [Dict, peft.ModelConfig] including keyword arguments, e.g. 
# @param script_arguments: [Dict, peft.ScriptArguments] including keyword arguments, e.g. "dataset_name", "dataset_train_split", "dataset_test_split"
# @param config_kwargs: [Dict] keyword arguments for updating TRL-Config, `ScriptArguments`, `ModelConfig`
#   - keyword arguments for `TRLConfig`: e.g. "output_dir", "adam_xxx", "learning_rate", "kl_coef", "push_to_hub"
#   - keyword arguments for `ScriptArguments`: e.g. "output_dir", "adam_xxx", "learning_rate", "kl_coef", "push_to_hub"
#   - keyword arguments for `ModelConfig`: e.g. "model_name_or_path", "torch_dtype", "trust_remote_code", "use_peft", "lora_xxx", "load_in_4bit", "bnb_4bit_compute_dtype", "bnb_4bit_quant_type"
# @param trainer_kwargs: [Dict] keyword arguments for updating TRL-Trainer
#   - keyword arguments for all Trainers: e.g. "data_collator", "callbacks"
#   - keyword arguments for `SFTTrainer`: e.g. "compute_loss_func", "compute_metrics"
#   - keyword arguments for `PPOTrainer`: e.g. "ref_model[required]", "reward_model[required]", "value_model[required]"
#   - keyword arguments for `DPOTrainer`: e.g. "ref_model"
#   - keyword arguments for `GRPOTrainer`: e.g. "reward_funcs[required]"
def base_pipeline(name, data_processor, config_kwargs, trainer_kwargs):# 1 ConfigurationTRLConfig, TRLTrainer = eval(f"{name}Config"), eval(f"{name}Trainer")parser = HfArgumentParser((ScriptArguments, TRLConfig, ModelConfig))script_arguments, trainer_config, model_config = parser.parse_args_into_dataclasses()script_arguments = update_trl_config(script_arguments, **config_kwargs)trainer_config = update_trl_config(trainer_config, **config_kwargs)model_config = update_trl_config(model_config, **config_kwargs)peft_config = get_peft_config(model_config)quantization_config = get_quantization_config(model_config)# 2 Load models and tokenizerlogging.info("Load models and tokenizer ...")logging.info(f"  - Model: {model_config.model_name_or_path}")tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)if not "pad_token" in tokenizer.special_tokens_map:tokenizer.add_special_tokens({"pad_token": "[PAD]"})if tokenizer.chat_template is None:tokenizer.chat_template = SIMPLE_CHAT_TEMPLATEmodel = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path,device_map = "auto",trust_remote_code = model_config.trust_remote_code,quantization_config = quantization_config,)if peft_config is not None:logging.info("Prepare model for PEFT ...")model.config.pretraining_tp = 1model.config.use_cache = Falsemodel.gradient_checkpointing_enable()# If `prepare_model_for_kbit_training` is ignored, and `gradient_checkpointing = True` (for GPU memory saving)# Then you need set `model.enable_input_require_grads()` yourself# model = prepare_model_for_kbit_training(model)model.enable_input_require_grads()model = get_peft_model(model, peft_config)if name == "PPO":logging.info("PPO load reward value and reference models ...")# PPO is special! It needs more components!logging.info(f"  - Reward model: {trainer_config.reward_model_path}")reward_model = AutoModelForSequenceClassification.from_pretrained(trainer_config.reward_model_path,trust_remote_code = model_config.trust_remote_code,num_labels = 1,)value_model = AutoModelForSequenceClassification.from_pretrained(trainer_config.reward_model_path,trust_remote_code = model_config.trust_remote_code,num_labels = 1,)logging.info("  - Copy reference model ...")ref_model = deepcopy(model)# ref_model = model.__class__(model.config)# ref_model.load_state_dict(model.state_dict())trainer_kwargs["reward_model"] = reward_modeltrainer_kwargs["value_model"] = value_modeltrainer_kwargs["ref_model"] = ref_modellogging.info("  - Done!")if data_processor is None:# The data processor of PPO is also different to othersdef data_processor(_data):outputs = tokenizer(_data["prompt"] + _data["completion"], padding = False)return {"input_ids": outputs["input_ids"]}# 2 Load datasetlogging.info("Load dataset ...")logging.info(f"  - Dataset: {script_arguments.dataset_name}")if data_processor is None:data_processor = generate_simple_data_processor(name)train_dataset = load_dataset(script_arguments.dataset_name, split=script_arguments.dataset_train_split)eval_dataset = load_dataset(script_arguments.dataset_name, split=script_arguments.dataset_test_split)train_dataset = train_dataset.map(data_processor, remove_columns=train_dataset.column_names)eval_dataset = eval_dataset.map(data_processor, remove_columns=eval_dataset.column_names)logging.info(f"  - Train dataset: {len(train_dataset)}")logging.info(f"  - Eval dataset: {len(eval_dataset)}")# 4 Train modellogging.info("Trainer starts ...")trainer = TRLTrainer(model = model,args = trainer_config,train_dataset = train_dataset,eval_dataset = eval_dataset,processing_class = tokenizer,peft_config = peft_config,**trainer_kwargs)trainer.train()logging.info("  - Trainer finishes!")# 5 Save modelif trainer_config.push_to_hub:logging.info(f"  - Push checkpoints to {trainer_config.organization}/{trainer_config.push_to_hub_model_id}")trainer.push_to_hub()logging.info(f"Save model to {trainer_config.output_dir}")trainer.save_model(trainer_config.output_dir)
# SFT Pipeline
def sft_pipeline(data_processor, config_kwargs, trainer_kwargs):base_pipeline(name = "SFT",data_processor = data_processor,config_kwargs = config_kwargs,trainer_kwargs = trainer_kwargs,)
# PPO Pipeline
def ppo_pipeline(data_processor, config_kwargs, trainer_kwargs):base_pipeline(name = "PPO",data_processor = data_processor,config_kwargs = config_kwargs,trainer_kwargs = trainer_kwargs,)
# DPO Pipeline
def dpo_pipeline(data_processor, config_kwargs, trainer_kwargs):base_pipeline(name = "DPO",data_processor = data_processor,config_kwargs = config_kwargs,trainer_kwargs = trainer_kwargs,)
# GRPO Pipeline
def grpo_pipeline(data_processor, config_kwargs, trainer_kwargs):base_pipeline(name = "GRPO",data_processor = data_processor,config_kwargs = config_kwargs,trainer_kwargs = trainer_kwargs,)

对应的单元测试脚本在trainer_pipelines.py,这里面涉及3个模型和4个数据集,都可以在huggingface上直接下载得到,model_homedataset_home可根据本地路径进行修改。

# -*- coding: utf8 -*-
# @author: caoyang
# @email: caoyang@stu.sufe.edu.cnimport os
import logging
from src.pipelines.trainer import base_pipeline, sft_pipeline, ppo_pipeline, dpo_pipeline, grpo_pipelinemodel_home = "/nfsshare/home/caoyang/resource/model"
dataset_home = "/nfsshare/home/caoyang/resource/dataset"
model_names = ["Qwen/Qwen2.5-0.5B-Instruct","EleutherAI/pythia-1b-deduped","EleutherAI/pythia-160m",
]dataset_names = ["trl-lib/tldr",	# train["prompt", "completion"] + validation["prompt", "completion"] + test["prompt", "completion"]"trl-lib/ultrafeedback_binarized",	# train["chosen", "rejected", "score_chosen", "score_rejected"] + test["chosen", "rejected", "score_chosen", "score_rejected"]"trl-internal-testing/descriptiveness-sentiment-trl-style", # sentiment["prompt", "chosen", "rejected"] + descriptiveness["prompt", "chosen", "rejected"]"YeungNLP/firefly-train-1.1M", # train["input", "target"]
]def sft_pipeline_test():logging.info("SFT unittest ...")model_name_or_path = os.path.join(model_home, model_names[0])dataset_name = os.path.join(dataset_home, dataset_names[0])data_processor = Noneconfig_kwargs = {"output_dir": f"./temp/sft+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}","model_name_or_path": model_name_or_path,"dataset_name": dataset_name,"trust_remote_code": True,"dataset_train_split": "train[:500]","dataset_test_split": "validation[500:600]","use_peft": True,"report_to": "none","lora_target_modules": ["q_proj", "k_proj", "v_proj"]}trainer_kwargs = {}sft_pipeline(data_processor, config_kwargs, trainer_kwargs)def ppo_pipeline_test():logging.info("PPO unittest ...")model_name_or_path = os.path.join(model_home, model_names[1])"""EleutherAI/pythia-1b-dedupedGPTNeoXForCausalLM((gpt_neox): GPTNeoXModel((embed_in): Embedding(50304, 2048)(emb_dropout): Dropout(p=0.0, inplace=False)(layers): ModuleList((0-15): 16 x GPTNeoXLayer((input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)(post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)(post_attention_dropout): Dropout(p=0.0, inplace=False)(post_mlp_dropout): Dropout(p=0.0, inplace=False)(attention): GPTNeoXAttention((query_key_value): Linear(in_features=2048, out_features=6144, bias=True)(dense): Linear(in_features=2048, out_features=2048, bias=True))(mlp): GPTNeoXMLP((dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)(dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)(act): GELUActivation())))(final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)(rotary_emb): GPTNeoXRotaryEmbedding())(embed_out): Linear(in_features=2048, out_features=50304, bias=False))"""dataset_name = os.path.join(dataset_home, dataset_names[0])reward_model_path = os.path.join(model_home, model_names[2])data_processor = Noneconfig_kwargs = {"output_dir": f"./temp/ppo+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}","model_name_or_path": model_name_or_path,"dataset_name": dataset_name,"reward_model_path": reward_model_path,"trust_remote_code": True,"dataset_train_split": "train[:500]","dataset_test_split": "validation[:100]","use_peft": True,"report_to": "none","lora_target_modules": ["query_key_value"],}trainer_kwargs = {}ppo_pipeline(data_processor, config_kwargs, trainer_kwargs)def dpo_pipeline_test():logging.info("DPO unittest ...")model_name_or_path = os.path.join(model_home, model_names[0])dataset_name = os.path.join(dataset_home, dataset_names[2])data_processor = Noneconfig_kwargs = {"output_dir": f"./temp/dpo+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}","model_name_or_path": model_name_or_path,"dataset_name": dataset_name,"trust_remote_code": True,"dataset_train_split": "descriptiveness[:500]","dataset_test_split": "descriptiveness[500:600]","use_peft": True,"report_to": "none","lora_target_modules": ["q_proj", "k_proj", "v_proj"]}trainer_kwargs = {}dpo_pipeline(data_processor, config_kwargs, trainer_kwargs)def grpo_pipeline_test():logging.info("GRPO unittest ...")model_name_or_path = os.path.join(model_home, model_names[0])dataset_name = os.path.join(dataset_home, dataset_names[0])data_processor = Nonedef reward_funcs(completions, **kwargs):return [float(len(set(completion))) for completion in completions]config_kwargs = {"output_dir": f"./temp/grpo+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}","model_name_or_path": model_name_or_path,"dataset_name": dataset_name,"trust_remote_code": True,"dataset_train_split": "train[:500]","dataset_test_split": "validation[:100]","use_peft": True,"report_to": "none","lora_target_modules": ["q_proj", "k_proj", "v_proj"]}trainer_kwargs = {"reward_funcs": reward_funcs,}grpo_pipeline(data_processor, config_kwargs, trainer_kwargs)

目前单元测试都能通过,关键库版本如下:

accelerate==1.6.0
datasets==3.5.0
peft==0.15.2
torch==2.5.1
transformers==4.51.3
trl==0.17.0

1.1 训练配置参数的设置(Config)

目前模型训练涉及的参数非常的广泛,以前搭积木时代训练模型时,我们一般会自定义一个Config类用于管理与模型训练数据集等相关的参数,但目前Transformers提供了非常好的工具HfArgumentParser用于管理这些繁杂的参数,根据官方给出的PPO示例(ppo.py)中的写法👇

parser = HfArgumentParser((ScriptArguments, PPOConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_into_dataclasses()

将所有参数分为三类:ScriptArguments, PPOConfig, ModelConfig

使用HfArgumentParser解析参数的逻辑是(这个过程没有读源码,可能实际上略有出入,但是结果应该是一致的),首先使用默认值初始化ScriptArguments, PPOConfig, ModelConfig对应的三个对象,然后将执行脚本中传入的参数按名称对应分配给每个对象。

例如,在PPO的例子的执行脚本中:

python -i examples/scripts/ppo/ppo.py \--dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \--dataset_train_split descriptiveness \--learning_rate 3e-6 \--output_dir models/minimal/ppo \--per_device_train_batch_size 64 \--gradient_accumulation_steps 1 \--total_episodes 10000 \--model_name_or_path EleutherAI/pythia-1b-deduped \--missing_eos_penalty 1.0

会将dataset_namedataset_train_split的值更新给script_argslearning_rate output_dir的值更新给model_args

不妨可以看看script_args, training_args, model_args分别包含了哪些参数:

  • script_args
ScriptArguments(dataset_name=None, dataset_config=None, dataset_train_split='train', dataset_test_split='test', gradient_checkpointing_use_reentrant=False, ignore_bias_buffers=False)
  • training_args(即ppo_config
PPOConfig(output_dir='trainer_output',overwrite_output_dir=False,do_train=False,do_eval=False,do_predict=False,eval_strategy=<IntervalStrategy.NO: 'no'>,prediction_loss_only=False,per_device_train_batch_size=8,per_device_eval_batch_size=8,per_gpu_train_batch_size=None,per_gpu_eval_batch_size=None,gradient_accumulation_steps=1,eval_accumulation_steps=None,eval_delay=0,torch_empty_cache_steps=None,learning_rate=5e-05,weight_decay=0.0,adam_beta1=0.9,adam_beta2=0.999,adam_epsilon=1e-08,max_grad_norm=1.0,num_train_epochs=3.0,max_steps=-1,lr_scheduler_type=<SchedulerType.LINEAR: 'linear'>,lr_scheduler_kwargs={},warmup_ratio=0.0,warmup_steps=0,log_level='passive',log_level_replica='warning',log_on_each_node=True,logging_dir='trainer_output\\runs\\Jun14_11-14-00_LAPTOP-PJP6MGE1',logging_strategy=<IntervalStrategy.STEPS: 'steps'>,logging_first_step=False,logging_steps=500,logging_nan_inf_filter=True,save_strategy=<SaveStrategy.STEPS: 'steps'>,save_steps=500,save_total_limit=None,save_safetensors=True,save_on_each_node=False,save_only_model=False,restore_callback_states_from_checkpoint=False,no_cuda=False,use_cpu=False,use_mps_device=False,seed=42,data_seed=None,jit_mode_eval=False,use_ipex=False,bf16=False,fp16=False,fp16_opt_level='O1',half_precision_backend='auto',bf16_full_eval=False,fp16_full_eval=False,tf32=None,local_rank=0,ddp_backend=None,tpu_num_cores=None,tpu_metrics_debug=False,debug=[],dataloader_drop_last=False,eval_steps=None,dataloader_num_workers=0,dataloader_prefetch_factor=None,past_index=-1,run_name='trainer_output',disable_tqdm=False,remove_unused_columns=True,label_names=None,load_best_model_at_end=False,metric_for_best_model=None,greater_is_better=None,ignore_data_skip=False,fsdp=[],fsdp_min_num_params=0,fsdp_config={'min_num_params': 0,'xla': False,'xla_fsdp_v2': False,'xla_fsdp_grad_ckpt': False},fsdp_transformer_layer_cls_to_wrap=None,accelerator_config=AcceleratorConfig(split_batches=False,dispatch_batches=None,even_batches=True,use_seedable_sampler=True,non_blocking=False,gradient_accumulation_kwargs=None,use_configured_state=False),deepspeed=None,label_smoothing_factor=0.0,optim=<OptimizerNames.ADAMW_TORCH: 'adamw_torch'>,optim_args=None,adafactor=False,group_by_length=False,length_column_name='length',report_to=['wandb'],ddp_find_unused_parameters=None,ddp_bucket_cap_mb=None,ddp_broadcast_buffers=None,dataloader_pin_memory=True,dataloader_persistent_workers=False,skip_memory_metrics=True,use_legacy_prediction_loop=False,push_to_hub=False,resume_from_checkpoint=None,hub_model_id=None,hub_strategy=<HubStrategy.EVERY_SAVE: 'every_save'>,hub_token=None,hub_private_repo=None,hub_always_push=False,gradient_checkpointing=False,gradient_checkpointing_kwargs=None,include_inputs_for_metrics=False,include_for_metrics=[],eval_do_concat_batches=True,fp16_backend='auto',push_to_hub_model_id=None,push_to_hub_organization=None,push_to_hub_token=None,mp_parameters='',auto_find_batch_size=False,full_determinism=False,torchdynamo=None,ray_scope='last',ddp_timeout=1800,torch_compile=False,torch_compile_backend=None,torch_compile_mode=None,include_tokens_per_second=False,include_num_input_tokens_seen=False,neftune_noise_alpha=None,optim_target_modules=None,batch_eval_metrics=False,eval_on_start=False,use_liger_kernel=False,eval_use_gather_object=False,average_tokens_across_devices=False,dataset_num_proc=None,num_mini_batches=1,total_episodes=None,local_rollout_forward_batch_size=64,num_sample_generations=10,response_length=53,stop_token=None,stop_token_id=None,temperature=0.7,missing_eos_penalty=None,sft_model_path='EleutherAI/pythia-160m',world_size=None,num_total_batches=None,micro_batch_size=None,local_batch_size=None,batch_size=None,local_mini_batch_size=None,mini_batch_size=None,exp_name='ppo_config',reward_model_path='EleutherAI/pythia-160m',model_adapter_name=None,ref_adapter_name=None,num_ppo_epochs=4,whiten_rewards=False,kl_coef=0.05,kl_estimator='k1',cliprange=0.2,vf_coef=0.1,cliprange_value=0.2,gamma=1.0,lam=0.95,ds3_gather_for_generation=True
)
  • model_args
ModelConfig(model_name_or_path=None, model_revision='main', torch_dtype=None, trust_remote_code=False, attn_implementation=None, use_peft=False, lora_r=16, lora_alpha=32, lora_dropout=0.05, lora_target_modules=None, lora_modules_to_save=None, lora_task_type='CAUSAL_LM', use_rslora=False, use_dora=False, load_in_8bit=False, load_in_4bit=False, bnb_4bit_quant_type='nf4', use_bnb_nested_quant=False,
)

其中:

  • script_args:主要控制数据集的参数。特别地,dataset_name是required,但这个事情比较奇怪

    • 如果是在Linux上使用脚本启动的话是必须传入--dataset_name或者--dataset-name参数的,否则会显示 error: the following arguments are required: --dataset_name/--dataset-nametransformers==4.51.3, trl==0.17.0
    • 在Windows系统上即使不传入任何参数也是不会报错的(transformers==4.52.4, trl==0.18.1,比Linux上的版本要高一点,但应该不是版本问题)。
  • training_args:即PPOConfig的对象,对应PPOTrainer中的args参数,其中包含了非常多与训练参数,具体每个训练器及其配置的参数列表可以直接到HuggingFace上的TRL文档查看(https://huggingface.co/docs/trl),这里例举常用的参数:

    • output_dir:模型checkpoint导出的路径,
    • adam_xxx:Adam优化器相关的参数,似乎TRL里所有Trainer的默认优化器都是Adam,不过是可以在PPOTrainer的参数中定义优化器的(PPOTrainer的优化器参数optimizers是需要传入两个变量(optimizers (tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR])),一个是优化器,另一个是学习率的规划器。
    • learning_rate:学习率,
    • kl_coef:KL散度的惩罚系数,越大相当于策略更新幅度越小,越小策略就可能更新幅度越快,
  • model_args对应的是PPOTrainerpeft_config参数,里面也有一些和模型初始化相关的参数,包括量化与PEFT相关的参数也是从中提取得到的,比如在base_pipeline脚本中:

...
# Trainer Pipeline
# @param name: [Str] e.g. "SFT", "PPO", "DPO", "GRPO"
# @param data_processor: Function object prepared for `dataset.map(data_processor)`
# @param trainer_config: [Dict, peft.XXXConfig] including keyword arguments, e.g. 
# @param model_config: [Dict, peft.ModelConfig] including keyword arguments, e.g. 
# @param script_arguments: [Dict, peft.ScriptArguments] including keyword arguments, e.g. "dataset_name", "dataset_train_split", "dataset_test_split"
# @param config_kwargs: [Dict] keyword arguments for updating TRL-Config, `ScriptArguments`, `ModelConfig`
#   - keyword arguments for `TRLConfig`: e.g. "output_dir", "adam_xxx", "learning_rate", "kl_coef", "push_to_hub"
#   - keyword arguments for `ScriptArguments`: e.g. "output_dir", "adam_xxx", "learning_rate", "kl_coef", "push_to_hub"
#   - keyword arguments for `ModelConfig`: e.g. "model_name_or_path", "torch_dtype", "trust_remote_code", "use_peft", "lora_xxx", "load_in_4bit", "bnb_4bit_compute_dtype", "bnb_4bit_quant_type"
# @param trainer_kwargs: [Dict] keyword arguments for updating TRL-Trainer
#   - keyword arguments for all Trainers: e.g. "data_collator", "callbacks"
#   - keyword arguments for `SFTTrainer`: e.g. "compute_loss_func", "compute_metrics"
#   - keyword arguments for `PPOTrainer`: e.g. "ref_model[required]", "reward_model[required]", "value_model[required]"
#   - keyword arguments for `DPOTrainer`: e.g. "ref_model"
#   - keyword arguments for `GRPOTrainer`: e.g. "reward_funcs[required]"
def base_pipeline(name, data_processor, config_kwargs, trainer_kwargs):# 1 ConfigurationTRLConfig, TRLTrainer = eval(f"{name}Config"), eval(f"{name}Trainer")parser = HfArgumentParser((ScriptArguments, TRLConfig, ModelConfig))script_arguments, trainer_config, model_config = parser.parse_args_into_dataclasses()script_arguments = update_trl_config(script_arguments, **config_kwargs)trainer_config = update_trl_config(trainer_config, **config_kwargs)model_config = update_trl_config(model_config, **config_kwargs)peft_config = get_peft_config(model_config)quantization_config = get_quantization_config(model_config)...

peft_configquantization_config都是可以现成的函数生成的,它们被定义在trl项目根目录下的trainer/utils.py中:

def get_quantization_config(model_args: ModelConfig) -> Optional[BitsAndBytesConfig]:if model_args.load_in_4bit:quantization_config = BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_compute_dtype=model_args.torch_dtype,  # For consistency with model weights, we use the same value as `torch_dtype`bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,bnb_4bit_quant_storage=model_args.torch_dtype,)elif model_args.load_in_8bit:quantization_config = BitsAndBytesConfig(load_in_8bit=True,)else:quantization_config = Nonereturn quantization_configdef get_peft_config(model_args: ModelConfig) -> "Optional[PeftConfig]":if model_args.use_peft is False:return Noneif not is_peft_available():raise ValueError("You need to have PEFT library installed in your environment, make sure to install `peft`. ""Make sure to run `pip install -U peft`.")peft_config = LoraConfig(task_type=model_args.lora_task_type,r=model_args.lora_r,target_modules=model_args.lora_target_modules,lora_alpha=model_args.lora_alpha,lora_dropout=model_args.lora_dropout,bias="none",use_rslora=model_args.use_rslora,use_dora=model_args.use_dora,modules_to_save=model_args.lora_modules_to_save,)return peft_config

从上面的源码可以看出,如果model_config.use_peft=False,则默认不启用PEFT,得到的peft_config也就是None,同理是否采用量化取决于load_in_4bitload_in_8bit是否至少有一个是True


1.2 加载分词器与模型(tokenizer & model)

回到base_pipeline中对应的部分:

	...# 2 Load models and tokenizerlogging.info("Load models and tokenizer ...")logging.info(f"  - Model: {model_config.model_name_or_path}")tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)if not "pad_token" in tokenizer.special_tokens_map:tokenizer.add_special_tokens({"pad_token": "[PAD]"})if tokenizer.chat_template is None:tokenizer.chat_template = SIMPLE_CHAT_TEMPLATEmodel = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path,device_map = "auto",trust_remote_code = model_config.trust_remote_code,quantization_config = quantization_config,)if peft_config is not None:logging.info("Prepare model for PEFT ...")model.config.pretraining_tp = 1model.config.use_cache = Falsemodel.gradient_checkpointing_enable()# If `prepare_model_for_kbit_training` is ignored, and `gradient_checkpointing = True` (for GPU memory saving)# Then you need set `model.enable_input_require_grads()` yourself# model = prepare_model_for_kbit_training(model)model.enable_input_require_grads()model = get_peft_model(model, peft_config)...

这里主要是几个零碎的注意点:

  1. 分词器的词汇表中一定要包含pad_token,具体可以通过tokenizer.special_tokens_map,否则在后面trainer.train()时会报错,提示你设置pad_token
  2. 关于tokenizer.chat_template的问题,这个到接下来的数据处理部分会详细说明,针对对话类的任务,一般是会采用chat_template来处理数据集的,因此可以用trl.trainer.utils中提供的SIMPLE_CHAT_TEMPLATE作为缺省
  3. 模型可以通过量化加载(即quantization_config不为None时),这样会节约内存,此时一般都会做如下的设置👇(如果不是量化加载的话,可能加上这些配置也没什么关系,所以我就这么写了)
    • 这里的pretraining_tp是张量并行的意思,数值越大,就会并行的数量越多,当然只有多卡才能并行,单卡只能设为1
    • use_cache应该是指(大概)做model.generate()生成时是否进行kv_cache缓存
    • model.enable_input_require_grads()也是节约内存的方法,我记得是缓存梯度以达到更快的反向传播运算的作用。
		...model.config.pretraining_tp = 1model.config.use_cache = Falsemodel.gradient_checkpointing_enable()# If `prepare_model_for_kbit_training` is ignored, and `gradient_checkpointing = True` (for GPU memory saving)# Then you need set `model.enable_input_require_grads()` yourself# model = prepare_model_for_kbit_training(model)model.enable_input_require_grads()...

然后对于PPOTrainer而言,非常特殊,它的构造参数还需要reward_model, value_model, ref_model三个参数,因此在base_pipeline中做了额外处理:

	if name == "PPO":logging.info("PPO load reward value and reference models ...")# PPO is special! It needs more components!logging.info(f"  - Reward model: {trainer_config.reward_model_path}")reward_model = AutoModelForSequenceClassification.from_pretrained(trainer_config.reward_model_path,trust_remote_code = model_config.trust_remote_code,num_labels = 1,)value_model = AutoModelForSequenceClassification.from_pretrained(trainer_config.reward_model_path,trust_remote_code = model_config.trust_remote_code,num_labels = 1,)logging.info("  - Copy reference model ...")ref_model = deepcopy(model)# ref_model = model.__class__(model.config)# ref_model.load_state_dict(model.state_dict())trainer_kwargs["reward_model"] = reward_modeltrainer_kwargs["value_model"] = value_modeltrainer_kwargs["ref_model"] = ref_modellogging.info("  - Done!")

这里注意一下:

  • 一般初始化ref_model和目标模型是完全一样的,这里测试了两种方法,一种是直接ref_model = deepcopy(model),这确实是可行的,不会发生unpickled之类的问题,另一种方法则是重构一个和目标模型结构相同的空模型,然后加载状态字典:
      # ref_model = model.__class__(model.config)# ref_model.load_state_dict(model.state_dict())
    
    这个方法在use_peft=False时是可行的,但是因为这里可能是一个peft_model,因此model.__class__的参数还需要一个peft_config,直接这样写是会报错的
  • 使用HfArgumentParser加载PPOConfig参数时,ppo_config.reward_model_path是有默认值的(可以看上面的training_args),即'EleutherAI/pythia-160m',这个奖励模型必须是可以被AutoModelForSequenceClassification类型加载的
  • value_model通常设置为和reward_model相同,至少在官方的PPO示例中是这样的,其实也可以解释,我的理解是:
    A ( s , a ) = Q ( s , a ) − V ( s ) A(s,a)=Q(s,a)-V(s) A(s,a)=Q(s,a)V(s)
    中,这里 V ( s ) V(s) V(s)value_model Q ( s , a ) = r ( s , a ) + γ V ( s ′ ) Q(s,a)=r(s,a)+\gamma V(s') Q(s,a)=r(s,a)+γV(s)中的 r ( s , a ) r(s,a) r(s,a)即奖励函数,在CAUSAL_LM的语境下,其实就是句子 s s s加了一个单词 a a a后的奖励,本质上都是对一句话进行评分。虽然说得通,但是感觉也挺奇怪的。

1.3 数据集分割与字段问题(dataset)

目前主流的数据集加载都是直接用HuggingFace的datasets.load_dataset方法加载,标准的数据格式就是每个样本的格式为{“column_1”: <data_1>, "column_2": <data_2>},整个数据集类似jsonl的格式。

针对目前已知的几个训练器,它们对数据集的格式要求大概是这样的:

  • SFTTrainer:字段要包含promptcompletion,注意,训练时只会在completion部分计算损失,prompt部分是不会计算损失的

    • 但是其实我发现如果数据集中只有text或者input+target这种字段时也是可行的,需要仔细读源码才行。
  • DPOTrainer:这个就很简单,一般就是prompt+chosen+rejected,但是在官方示例使用的数据集trl-lib/ultrafeedback_binarized中,字段却是chosen, rejected, score_chosen, score_rejected,也就是说其实prompt可能并非必须,本来也确实可以直接用空字符串替代,然后数据集中如果没有回答进行评分的话,可能还有一个默认的评分机制在训练器里面。

  • GRPOTrainer:这个我也是根据官方示例来的,我看到它使用的数据集是trl-lib/tldr,这是一个典型的只有promptcompletion两个字段的数据集

  • PPOTrainer这个最为特殊,之前DeepSeek一直写不对的原因也在于此

    • 在官方给出的PPO运行脚本(ppo.py)中

          def prepare_dataset(dataset, tokenizer):"""pre-tokenize the dataset before training; only collate during training"""def tokenize(element):outputs = tokenizer(element[dataset_text_field],padding=False,)return {"input_ids": outputs["input_ids"]}return dataset.map(tokenize,batched=True,remove_columns=dataset.column_names,num_proc=training_args.dataset_num_proc,)
      

      注意到这里是直接处理成分词后的input_ids格式

    • 但是,我本来以为PPO和GRPO是类似的,因此理论上在数据集的格式要求上应该也差不了太多,因此我觉得可能prompt+completion也是可行的,但是测试下来会报错,报错提示是一定要求是带input_ids字段的。

    • 因此我单独给PPO写了一个数据处理的data_processor,而其余的都是很简单的:

      # @param name: [Str] e.g. "SFT", "PPO", "DPO", "GRPO"
      def generate_simple_data_processor(name, **kwargs):if name in ["SFT", "GRPO"]:def _data_processor(_data):return {"prompt": _data["prompt"], "completion": _data["completion"]}elif name == "PPO":tokenizer = kwargs.get("tokenizer")def _data_processor(_data):outputs = tokenizer(_data["prompt"] + _data["completion"], padding = False)return {"input_ids": outputs["input_ids"]}elif name == "DPO":def _data_processor(_data):return {"prompt": _data["prompt"], "chosen": _data["chosen"], "rejected": _data["rejected"]}else:raise NotImplementedError(name)return _data_processor
      

当然一些写法也会使用apply_chat_template,这通常是用于交互式对话任务,都是可行的:

		def _data_processor(_data):_message = [{"role": "system", "content": "You are an AI assistant developped by CY"},{"role": "user", "content": _data[dataset_input_column]},{"role": "assistant", "content": _data[dataset_target_column]},]_prompt = tokenizer.apply_chat_template(_message, tokenize=False)return {"text": _prompt}

然后一个小细节是在用dataset.map进行数据处理时一般会设置参数emove_columns=dataset.column_names,以丢弃原先不必要的字段,防止Trainer在运行时错用了其他的字段进行训练。

1.4 训练与保存(checkpoint)

base_pipeline中最后一部分训练完然后保存模型,顺利地话就完整地跑通:

	# 4 Train modellogging.info("Trainer starts ...")trainer = TRLTrainer(model = model,args = trainer_config,train_dataset = train_dataset,eval_dataset = eval_dataset,processing_class = tokenizer,peft_config = peft_config,**trainer_kwargs)trainer.train()logging.info("  - Trainer finishes!")# 5 Save modelif trainer_config.push_to_hub:logging.info(f"  - Push checkpoints to {trainer_config.organization}/{trainer_config.push_to_hub_model_id}")trainer.push_to_hub()logging.info(f"Save model to {trainer_config.output_dir}")trainer.save_model(trainer_config.output_dir)

这里值得注意的几个问题:

  1. trainer_config(即SFTConfig, PPOConfig, DPOConfig, GRPOConfig之类的对象)中,有一个参数trainer_config.report_to,这个不设置的话默认是会上传到W&B的,网络不支持访问的话是会在训练到checkpoint的时候发生网络错误的,因此一般会设置成"none",或者也可以用"tensorboard"

    • 不过就训练绘图的话,训练结束到trainer_config.output_dir中找到checkpoint-xxx文件夹,里面会有trainer_state.json文件,然后自己根据里面的数据绘图即可,例如:
      # Plot dynamics of TRL trainer state
      def plot_trl_dynamics(trainer_state_path):with open(trainer_state_path, 'r', encoding="utf8") as f:data = json.load(f)log_history = data["log_history"]steps = [entry["step"] for entry in log_history]episodes = [entry["episode"] for entry in log_history]epochs = [entry["epoch"] for entry in log_history]policy_loss = [entry["loss/policy_avg"] for entry in log_history]value_loss = [entry["loss/value_avg"] for entry in log_history]lrs = [entry["lr"] for entry in log_history]entropys = [entry["objective/entropy"] for entry in log_history]kls = [entry["objective/kl"] for entry in log_history]non_score_rewards = [entry["objective/non_score_reward"] for entry in log_history]rlhf_rewards = [entry["objective/rlhf_reward"] for entry in log_history]scores = [entry["objective/scores"] for entry in log_history]plt.figure(figsize=(8, 8))ax_1 = plt.subplot(2, 2, 1)ax_2 = plt.subplot(4, 2, 2)ax_3 = plt.subplot(4, 2, 4)ax_4 = plt.subplot(2, 2, 3)ax_5 = plt.subplot(2, 2, 4)ax_1.plot(steps, policy_loss, label="Policy Loss")ax_1.plot(steps, value_loss, label="Value Loss", linestyle="--")ax_1.set_xlabel("Step"), ax_1.set_ylabel("Loss"), ax_1.legend()ax_1.set_title("Policy and Value Loss")# ------------------------------------------------------------------ax_2.plot(steps, kls, label="objective/kl")ax_2.set_xlabel("Step"), ax_2.set_ylabel("KL"), ax_2.legend()ax_2.set_title("KL Curve")# ------------------------------------------------------------------ax_3.plot(steps, entropys, label="objective/entropy")ax_3.set_xlabel("Step"), ax_3.set_ylabel("Entropy"), ax_3.legend()ax_3.set_title("Entropy Curve")# ------------------------------------------------------------------ax_4.plot(steps, lrs, label="Learning Rate")ax_4.set_xlabel("Step"), ax_4.set_ylabel("Learning Rate"), ax_4.legend()ax_4.set_title("Learning Rate Curve")# ------------------------------------------------------------------ax_5.plot(steps, non_score_rewards, label="objective/non_score_reward", linestyle="--")ax_5.plot(steps, rlhf_rewards, label="objective/rlhf_reward", linestyle="--")ax_5.plot(steps, scores, label="objective/scores")ax_5.set_xlabel("Step"), ax_5.set_ylabel("Score/Reward"), ax_5.legend()ax_5.set_title("Reward and Score")plt.show()print(steps)print('-' * 64)print(episodes)print('-' * 64)print(epochs)
      
  2. 关于W&B的使用:如果网络支持,使用wandb确实也是不错的选择:

    登录 wandb

    方法 1:交互式登录(推荐)

    运行以下代码,会弹出浏览器或要求输入 API Key:

    import wandb
    wandb.login()
    

    然后按提示操作即可。

    方法 2:直接提供 API Key

    import wandb
    wandb.login(key="your_api_key_here")  # 替换成你的 wandb API Key
    

    ⚠️ 注意:不要直接在代码里硬编码 API Key(特别是公开代码时),建议使用环境变量或配置文件。

    方法 3:通过环境变量设置(适合脚本/服务器)

    export WANDB_API_KEY=your_api_key_here
    

    然后在 Python 中直接 wandb.init() 即可。


    步骤 3:在 Trainer 中启用 wandb

    方式 1:自动记录(默认)

    HuggingFace Trainer 默认会使用 wandb(如果已安装并登录):

    from transformers import Trainer, TrainingArguments
    training_args = TrainingArguments(output_dir="./results",report_to="wandb",  # 明确指定使用 wandb(默认已经是这样)logging_dir="./logs",# 其他训练参数...
    )
    trainer = Trainer(model=model,args=training_args,train_dataset=train_dataset,eval_dataset=eval_dataset,
    )
    trainer.train()  # 自动记录到 wandb
    

    方式 2:手动初始化 wandb(更灵活控制)

    import wandb
    # 初始化 wandb 并设置项目名称
    wandb.init(project="my-ppo-project", name="experiment-1")
    # 然后正常训练
    trainer.train()
    # 训练结束后(可选)
    wandb.finish()
    

    这样你可以自定义 wandb 的配置(如 projectnameconfig 等)。


额外配置(可选)

1. 记录超参数

config = {"learning_rate": 2e-5,"batch_size": 32,"epochs": 3,
}
wandb.init(project="my-project", config=config)

训练时,这些参数会显示在 wandb 面板中。

2. 自定义记录指标

# 在训练循环中记录自定义指标
wandb.log({"reward": current_reward, "loss": current_loss})

3. 恢复中断的训练

wandb.init(id="之前运行的ID", resume="must")
  1. 训练开始后,终端会显示类似:
    Tracking run with wandb, view at https://wandb.ai/your-username/your-project/runs/xxx
    
  2. 打开链接即可查看实时训练数据。

1.5 单元测试

单元测试脚本trainer_pipelines.py:

# -*- coding: utf8 -*-
# @author: caoyang
# @email: caoyang@stu.sufe.edu.cnimport os
import logging
from src.pipelines.trainer import base_pipeline, sft_pipeline, ppo_pipeline, dpo_pipeline, grpo_pipelinemodel_home = "/nfsshare/home/caoyang/resource/model"
dataset_home = "/nfsshare/home/caoyang/resource/dataset"
model_names = ["Qwen/Qwen2.5-0.5B-Instruct","EleutherAI/pythia-1b-deduped","EleutherAI/pythia-160m",
]dataset_names = ["trl-lib/tldr",	# train["prompt", "completion"] + validation["prompt", "completion"] + test["prompt", "completion"]"trl-lib/ultrafeedback_binarized",	# train["chosen", "rejected", "score_chosen", "score_rejected"] + test["chosen", "rejected", "score_chosen", "score_rejected"]"trl-internal-testing/descriptiveness-sentiment-trl-style", # sentiment["prompt", "chosen", "rejected"] + descriptiveness["prompt", "chosen", "rejected"]"YeungNLP/firefly-train-1.1M", # train["input", "target"]
]def sft_pipeline_test():logging.info("SFT unittest ...")model_name_or_path = os.path.join(model_home, model_names[0])dataset_name = os.path.join(dataset_home, dataset_names[0])data_processor = Noneconfig_kwargs = {"output_dir": f"./temp/sft+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}","model_name_or_path": model_name_or_path,"dataset_name": dataset_name,"trust_remote_code": True,"dataset_train_split": "train[:500]","dataset_test_split": "validation[500:600]","use_peft": True,"report_to": "none","lora_target_modules": ["q_proj", "k_proj", "v_proj"]}trainer_kwargs = {}sft_pipeline(data_processor, config_kwargs, trainer_kwargs)def ppo_pipeline_test():logging.info("PPO unittest ...")model_name_or_path = os.path.join(model_home, model_names[1])"""EleutherAI/pythia-1b-dedupedGPTNeoXForCausalLM((gpt_neox): GPTNeoXModel((embed_in): Embedding(50304, 2048)(emb_dropout): Dropout(p=0.0, inplace=False)(layers): ModuleList((0-15): 16 x GPTNeoXLayer((input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)(post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)(post_attention_dropout): Dropout(p=0.0, inplace=False)(post_mlp_dropout): Dropout(p=0.0, inplace=False)(attention): GPTNeoXAttention((query_key_value): Linear(in_features=2048, out_features=6144, bias=True)(dense): Linear(in_features=2048, out_features=2048, bias=True))(mlp): GPTNeoXMLP((dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)(dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)(act): GELUActivation())))(final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)(rotary_emb): GPTNeoXRotaryEmbedding())(embed_out): Linear(in_features=2048, out_features=50304, bias=False))"""dataset_name = os.path.join(dataset_home, dataset_names[0])reward_model_path = os.path.join(model_home, model_names[2])data_processor = Noneconfig_kwargs = {"output_dir": f"./temp/ppo+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}","model_name_or_path": model_name_or_path,"dataset_name": dataset_name,"reward_model_path": reward_model_path,"trust_remote_code": True,"dataset_train_split": "train[:500]","dataset_test_split": "validation[:100]","use_peft": True,"report_to": "none","lora_target_modules": ["query_key_value"],}trainer_kwargs = {}ppo_pipeline(data_processor, config_kwargs, trainer_kwargs)def dpo_pipeline_test():logging.info("DPO unittest ...")model_name_or_path = os.path.join(model_home, model_names[0])dataset_name = os.path.join(dataset_home, dataset_names[2])data_processor = Noneconfig_kwargs = {"output_dir": f"./temp/dpo+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}","model_name_or_path": model_name_or_path,"dataset_name": dataset_name,"trust_remote_code": True,"dataset_train_split": "descriptiveness[:500]","dataset_test_split": "descriptiveness[500:600]","use_peft": True,"report_to": "none","lora_target_modules": ["q_proj", "k_proj", "v_proj"]}trainer_kwargs = {}dpo_pipeline(data_processor, config_kwargs, trainer_kwargs)def grpo_pipeline_test():logging.info("GRPO unittest ...")model_name_or_path = os.path.join(model_home, model_names[0])dataset_name = os.path.join(dataset_home, dataset_names[0])data_processor = Nonedef reward_funcs(completions, **kwargs):return [float(len(set(completion))) for completion in completions]config_kwargs = {"output_dir": f"./temp/grpo+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}","model_name_or_path": model_name_or_path,"dataset_name": dataset_name,"trust_remote_code": True,"dataset_train_split": "train[:500]","dataset_test_split": "validation[:100]","use_peft": True,"report_to": "none","lora_target_modules": ["q_proj", "k_proj", "v_proj"]}trainer_kwargs = {"reward_funcs": reward_funcs,}grpo_pipeline(data_processor, config_kwargs, trainer_kwargs)

这里主要就是grpo_pipeline_test中的reward_funcs,其他几个测试函数都没什么好说的,很好理解

GRPOTrainer中有required参数reward_funcs,这个在TRL文档中给出的例子是:

	def reward_funcs(completions, **kwargs):return [float(len(set(completion))) for completion in completions]

这个奖励的意思是会给长回答更高的得分。

注意这里的**kwargs是不可以省略的,因为GRPOTrainer调用reward_funcs会传入其他参数,比如prompts或者ground_truths,在TRL的GRPO文档中有很多例子。


2 一些其他的问题

这部分摘取一些和DeepSeek讨论的问题

2.1 DataProcessor与DataCollator

关于TRL的Trainer的一个问题,因为Trainer是可以传入data_collator参数的,但是一般来说我们都会在向Trainer传入数据集(如train_dataset)之前,先写好一个函数用于预处理数据集,比如train_dataset = train.dataset.map(prompt_formatter)这样,我的理解其实这里prompt_formatter何data_collator的作用是一样的(我觉得它们处理数据的逻辑也是一样,就是传入一个样本,然后做数据处理),因此如果在Trainer外面已经做了train_dataset.map(prompt_formatter)的操作,就不需要再传入data_collator了。反之也是一样,我可以传入data_collator,而无需进行train_dataset.map(prompt_formatter)的操作

以下回答由AI生成:


你的理解基本正确,但有一些细节需要注意。prompt_formatter(在dataset.map中使用)和data_collator确实都涉及数据处理,但它们的作用时机和目的有所不同。下面我来详细解释它们的区别以及何时使用它们:

2.1.1 dataset.map(prompt_formatter):

  • 作用时机: 在数据加载时,预处理阶段(即在训练开始之前)对数据集中的每个样本进行处理。
  • 典型用途:
    • 将原始文本转换为模型需要的输入格式(例如添加特殊token、构建prompt-completion结构)。
    • 对数据进行过滤(通过remove_columns删除不需要的字段)。
    • 执行一次性的、确定性的转换(例如tokenization,如果你不启用动态padding)。
  • 特点:
    • 处理结果是静态的(处理后的数据会被缓存,除非设置load_from_cache_file=False)。
    • 适合所有样本共享的逻辑(例如统一的prompt模板)。

2.1.2 data_collator:

  • 作用时机: 在训练过程中,动态批处理阶段(即在每个batch生成时)对样本进行实时处理。
  • 典型用途:
    • 动态padding(将同一batch内的样本padding到相同长度)。
    • 随机数据增强(例如随机masking,适用于MLM任务)。
    • 构造特殊输入(例如为RLHF构建奖励模型的输入格式)。
  • 特点:
    • 处理是动态的(每次batch采样时重新执行)。
    • 适合需要随机性依赖batch内其他样本的操作(例如padding长度取决于当前batch)。

2.1.3 何时选择哪种方式?

  • 只用dataset.map:

    • 如果你的数据处理是确定性的、不需要动态调整(例如固定的prompt模板),且不需要batch级别的操作(如padding),可以只使用dataset.map
    • 但注意:如果涉及变长序列,仍需data_collator处理padding(除非你已手动填充到固定长度)。
  • 只用data_collator:

    • 如果数据处理需要动态性(例如随机masking)或依赖batch上下文(如padding),可以只用data_collator
    • 但注意:这样可能导致训练时重复处理数据(例如每次epoch重复tokenization),影响效率。
  • 联合使用:

    • 常见做法是:
      1. dataset.map处理静态逻辑(例如tokenization、添加固定prompt)。
      2. data_collator处理动态逻辑(例如padding、随机增强)。
    • 例如:
      def prompt_formatter(example):example["input_ids"] = tokenizer(example["text"]).input_idsreturn example  # 不padding,保留原始长度train_dataset = raw_dataset.map(prompt_formatter, batched=False)data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=True)  # 动态padding和masking
      
  • 如果prompt_formatter已经完成了所有必要处理(包括padding到固定长度),且不需要任何动态逻辑,确实可以跳过data_collator(或传入default_data_collator,它仅做简单的batch堆叠)。

  • 但如果需要动态操作(如padding),即使有dataset.map,仍需data_collator

  • 二者并非互斥,而是互补的。dataset.map用于离线预处理,data_collator用于在线批处理。

  • 选择取决于你的数据处理需求是否依赖动态上下文或随机性。


2.2 关于PartialState

accelerate库里有一个方法:PartialState().local_main_process_first(),我看到这个方法会用来加载数据集,这个看起来像是并行加载数据集的意思,你可以解释一下这个函数是用来做什么的吗?

注:这个方法本来是在PPO官方示例中使用的


PartialState().local_main_process_first() 是 Hugging Face accelerate 库中用于多进程并行处理的一个关键方法,尤其在分布式训练或推理时控制数据加载、文件处理等操作的执行顺序。它的核心作用是确保某些操作(如数据集加载或文件预处理)只在**主进程(main process)**上优先执行,避免多进程竞争或重复操作。


2.2.1 核心功能

  • local_main_process_first() 返回一个上下文管理器(context manager),其行为如下:
    • 主进程(local main process):会立即执行上下文内的代码块。
    • 其他进程(非主进程):会等待主进程完成后,再跳过或执行相同操作(取决于具体逻辑)。
  • 目的:避免多进程同时操作共享资源(如文件读取、数据集缓存)导致的冲突或冗余计算。

2.2.2 典型使用场景

(1) 加载数据集时避免重复缓存

在分布式训练中,如果多个进程同时尝试缓存数据集(例如通过 datasets 库),可能会导致文件写入冲突。通过 local_main_process_first() 可以确保只有主进程生成缓存,其他进程直接复用:

from accelerate import PartialStatestate = PartialState()  # 初始化分布式状态# 只有主进程先缓存数据集,其他进程等待后直接加载缓存
with state.local_main_process_first():dataset = load_dataset("imdb")  # 假设这是耗时操作

(2) 文件预处理或下载

如果预处理(如分词、数据转换)只需要执行一次:

with state.local_main_process_first():tokenized_data = dataset.map(tokenize_function, batched=True)  # 主进程先处理
# 其他进程会自动等待主进程完成,然后直接使用处理后的数据

2.2.3 参数与底层机制

  • PartialState:是 accelerate 中管理分布式状态的类,自动检测当前进程的分布式环境(如多GPU、TPU等)。
  • local_main_process:指当前节点(机器)上的主进程(通常是全局主进程的一个子集,在多机训练时每台机器有自己的 local_main_process)。
  • 同步机制:底层通过 torch.distributed 或类似库实现进程间同步,确保非主进程等待主进程完成。

2.2.4 与类似方法的区别

方法/属性作用范围用途
local_main_process_first()当前节点的主进程优先避免文件/缓存操作的竞争
main_process_first()全局主进程优先(跨所有节点)适用于全局只需执行一次的操作
on_main_process()仅在主进程执行完全跳过其他进程(如日志记录)

2.2.5 完整示例(分布式训练中的数据加载)

from accelerate import PartialState
from datasets import load_dataset# 初始化分布式状态(自动检测环境)
state = PartialState()# 只有主进程优先加载并缓存数据,其他进程等待
with state.local_main_process_first():dataset = load_dataset("glue", "mrpc", split="train")tokenized_data = dataset.map(lambda x: tokenizer(x["sentence"]), batched=True)# 所有进程继续同步训练
for batch in tokenized_data:outputs = model(**batch)...

2.2.6 注意事项

  1. 缓存路径一致性:确保所有进程能访问相同的缓存路径(如共享存储)。
  2. 性能影响:主进程的预处理时间会被计入总耗时,但避免了重复操作的开销。
  3. 错误处理:如果主进程预处理失败,其他进程会一直等待,需添加超时或错误处理逻辑。

2.2.7 总结

  • 何时使用:在分布式环境中,任何需要单例执行的操作(如数据缓存、文件下载、预处理)都应使用 local_main_process_first()
  • 优势:减少资源竞争,避免冗余计算,确保分布式任务的高效协同。

如果需要进一步了解 accelerate 的分布式策略(如多机训练),可以继续探讨!


2.3 DPO和PPO的reference_model的区别

这个是阅读TRL文档发现的一个问题,DPOTrainerref_model参数是有默认值None的,但是PPOTrainerref_model参数没有默认值。这意味着必须给PPOTrainer传入ref_model参数,尽管一般都是跟目标模型一样。


你提出了一个很好的问题!DPO(Direct Preference Optimization)虽然确实不属于传统强化学习的范畴,但它仍然保留了reference model,这是出于特定的设计考虑。下面我详细解释原因:

2.3.1 DPO的核心思想与Reference Model的作用

DPO通过直接优化偏好数据来规避强化学习,但其目标函数(源自Bradley-Terry模型)隐式地包含了一个KL散度约束:
L DPO = − E [ log ⁡ σ ( β log ⁡ π θ ( y w ∣ x ) π ref ( y w ∣ x ) − β log ⁡ π θ ( y l ∣ x ) π ref ( y l ∣ x ) ) ] \mathcal{L}_{\text{DPO}} = -\mathbb{E} \left[ \log \sigma\left( \beta \log \frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)} - \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{\text{ref}}(y_l|x)} \right) \right] LDPO=E[logσ(βlogπref(ywx)πθ(ywx)βlogπref(ylx)πθ(ylx))]
其中:

  • π θ \pi_\theta πθ 是待训练的策略(当前模型)。
  • π ref \pi_{\text{ref}} πrefreference model(通常初始化为预训练模型,例如SFT模型)。
  • β \beta β 是控制偏离参考模型程度的超参数。

Reference Model的作用

  • 防止过度偏离初始策略:KL散度约束(通过reference model体现)避免模型为迎合偏好数据而退化(例如生成极端或无意义的输出)。
  • 隐式正则化:类似于PPO中的KL惩罚,但DPO通过直接在损失函数中对比当前模型和reference模型的输出概率实现。

2.3.2 为什么DPO需要Reference Model,尽管它不是RL?

  • 数学推导的必然性:DPO的公式是从逆向强化学习的目标推导而来的,其中reference model代表初始策略(即RL中的 π old \pi_{\text{old}} πold)。去掉reference model会导致目标函数失去约束,模型可能过拟合偏好数据。
  • 实践中的稳定性:实验表明,去掉reference model后,模型容易过度优化偏好信号(例如始终输出短答案或重复模式)。

2.3.3 TRL的DPOTrainer中的ref_model

在HuggingFace TRL库中,DPOTrainer确实需要传入ref_model参数,原因如下:

  1. 计算概率比:DPO需要比较当前模型和reference模型对正/负样本的概率(即(\log \frac{\pi_\theta(y|x)}{\pi_{\text{ref}}(y|x)}))。
  2. 冻结参考模型ref_model在训练过程中是冻结的(不更新),仅作为基准。
  3. 内存优化:实际实现中,ref_model可能与当前模型共享部分结构(例如共享底层Transformer),但通过梯度截断避免更新。

2.3.4 DPO vs PPO的Reference Model

特性PPO (RL)DPO (Non-RL)
Reference Model旧策略 π old \pi_{\text{old}} πold预训练模型 π ref \pi_{\text{ref}} πref
更新方式定期同步(on-policy要求)完全冻结
作用重要性采样校正 + KL约束隐式KL约束 + 防止过拟合

2.3.5 如果没有Reference Model会怎样?

假设令 π ref \pi_{\text{ref}} πref为均匀分布,则DPO退化为直接最大化偏好对的似然,可能导致:

  • 模型过度自信(overconfident)。
  • 生成内容多样性下降(模式坍塌)。
  • 对噪声偏好数据过拟合。

DPO虽然跳过了显式的强化学习框架,但通过保留reference model,继承了RL中“约束策略更新”的核心思想。这是其稳定性和性能的关键设计,也是它区别于纯监督学习(如直接微调)的重要特征。

PS:我感觉这个回答没有解释我的问题,答非所问了

http://www.xdnf.cn/news/14232.html

相关文章:

  • 自用方案总结
  • 如何在 Elementary OS 上安装 Snap Store
  • Thymeleaf功能标签th:block
  • java面试总结-20250605
  • 5G核心网Non-IP数据报文转发机制:Unstructured会话与协议栈解析
  • 抖音 pc + 翻页
  • C#最佳实践:推荐使用泛型而非装箱、拆箱
  • 60、数据访问-数据库场景的自动配置分析与整合测试
  • c++26新功能—契约编程
  • 单测时如何让 mock 的接口在长链路调用时一直生效
  • 从STM32到NXP:GPIO就像装修房子,多了个“智能开关”
  • 基于 SpringBoot+Servlet+JSP 的医院医保管理系统的设计与实现,论文7000字,可根据实际情况调整
  • ES+索引库文档操作
  • [CVPR 2025] DiCo:动态协作网络助力半监督3D血管分割新突破
  • AI Agent实战 - LangChain+Playwright构建火车票查询Agent
  • 人工智能学习28-BP过拟合
  • [k8s]--exec探针详细解析
  • java常见第三方依赖以及相关安全问题
  • http1.x VS http2.x 协议
  • Spring Cloud Alibaba 中间件
  • 硬编码(修改RIP相关指令)
  • HTML+CSS 半透明登录框
  • (LeetCode每日一题) 2566. 替换一个数字后的最大差值 ( 贪心 )
  • 安防市场的中小企业突围——从竞品分析到破局路径的思考
  • Spring Boot中Controller层规划与最佳实践详解
  • 【北京迅为】iTOP-4412精英版使用手册-第二十一章 延时函数专题
  • Python爬虫-批量爬取快手视频并将视频下载保存到本地
  • BeckHoff PLC --> 料筐(KLT Box)自动对中与抓取程序分析
  • Deep Research Agent的深度与广度如何保证
  • OSGI 是什么,有哪些具体应用、java8、9、10、11比较