当前位置: 首页 > news >正文

transformers 的Trainer的用法

transformers 库中的 Trainer 类是一个高级 API,它简化了训练和评估 transformer 模型的流程。下面我将从核心概念、基本用法到高级技巧进行全面讲解:

​​1. 核心功能​​

Trainer 自动处理以下任务:

  • ​​训练循环​​:自动实现 epoch 迭代、批次加载
  • ​​优化器&学习率调度​​:内置 AdamW 并支持自定义
  • ​​分布式训练​​:自动支持单机多卡(DataParallel/DistributedDataParallel)
  • ​​混合精度训练​​:通过 fp16=True 启用
  • 日志记录​​:集成 TensorBoard、Weights & Biases 等
  • ​​模型保存​​:定期保存检查点 + 最佳模型保存
  • ​​评估指标计算​​:自动计算验证集指标

​​2. 基础使用步骤​​

步骤 1:准备组件​​

from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
from datasets import load_dataset# 加载模型
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")# 加载数据集
dataset = load_dataset("glue", "mrpc")
train_set = dataset["train"]
eval_set = dataset["validation"]# 加载分词器
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")# 数据预处理函数
def tokenize(batch):return tokenizer(batch["sentence1"], batch["sentence2"], padding="max_length", truncation=True)train_set = train_set.map(tokenize, batched=True)
eval_set = eval_set.map(tokenize, batched=True)

步骤 2:配置训练参数​​

training_args = TrainingArguments(output_dir="./results",             # 输出目录num_train_epochs=3,                 # 训练轮数per_device_train_batch_size=16,     # 每个GPU的批次大小per_device_eval_batch_size=64,      # 验证批次大小evaluation_strategy="epoch",        # 每个epoch后评估save_strategy="epoch",              # 每个epoch后保存模型learning_rate=2e-5,                 # 学习率weight_decay=0.01,                  # 权重衰减fp16=True,                          # 混合精度训练logging_dir="./logs",               # 日志目录report_to="tensorboard",            # 日志工具load_best_model_at_end=True,        # 训练结束加载最佳模型metric_for_best_model="accuracy",   # 最佳模型指标
)

步骤 3:自定义评估指标​​

import numpy as np
from sklearn.metrics import accuracy_scoredef compute_metrics(eval_pred):logits, labels = eval_predpredictions = np.argmax(logits, axis=-1)acc = accuracy_score(labels, predictions)return {"accuracy": acc}

步骤 4:实例化 Trainer​​

trainer = Trainer(model=model,args=training_args,train_dataset=train_set,eval_dataset=eval_set,compute_metrics=compute_metrics,  # 自定义评估函数
)

步骤 5:开始训练​​

trainer.train()

步骤 6:评估与预测​​

results = trainer.evaluate()
print(results)
# 对新数据预测
test_inputs = tokenizer("Hello world!", return_tensors="pt")
predictions = trainer.predict(test_inputs)

​​3. 高级特性与技巧​​

3.1 自定义优化器​​

from torch.optim import AdamW
from transformers import get_scheduleroptimizer = AdamW(model.parameters(), lr=2e-5)
lr_scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=100, num_training_steps=1000
)trainer = Trainer(optimizers=(optimizer, lr_scheduler),... # 其他参数
)

3.2 ​​回调函数(Callbacks)​​

from transformers import TrainerCallbackclass CustomCallback(TrainerCallback):def on_epoch_end(self, args, state, control, **kwargs):print(f"Epoch {state.epoch} finished!")trainer.add_callback(CustomCallback())

3.3 恢复训练​​

trainer.train(resume_from_checkpoint=True)  # 自动加载最新检查点
# 或指定具体路径
# trainer.train(resume_from_checkpoint="/path/to/checkpoint")

3.4 梯度累积与裁剪​​

在 TrainingArguments 中设置:

gradient_accumulation_steps=4,  # 每4个步骤更新一次权重
max_grad_norm=1.0               # 梯度裁剪阈值

3.5 ​​类权重(不平衡数据)​​

from torch import nn
class WeightedTrainer(Trainer):def compute_loss(self, model, inputs, return_outputs=False):labels = inputs.pop("labels")outputs = model(**inputs)loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0]))  # 类权重loss = loss_fct(outputs.logits, labels)return (loss, outputs) if return_outputs else loss

​​4. 常见问题解决​​

4.1 内存不足(OOM)​​

  • 降低 per_device_train_batch_size
  • 启用梯度累积 (gradient_accumulation_steps)
  • 使用 fp16 混合精度
  • 尝试梯度检查点(在模型配置中设置 gradient_checkpointing=True)

4.2 自定义数据加载​​

继承 Trainer 并重写 get_train_dataloader 方法:

class CustomTrainer(Trainer):def get_train_dataloader(self):return DataLoader(..., collate_fn=custom_collate)

4.3 ​​多标签分类​​

需要自定义损失函数:

class MultiLabelTrainer(Trainer):def compute_loss(self, model, inputs, return_outputs=False):labels = inputs.pop("labels")outputs = model(**inputs)logits = outputs.logitsloss_fct = nn.BCEWithLogitsLoss()loss = loss_fct(logits, labels.float())return (loss, outputs) if return_outputs else loss

5. 完整工作流示例​​

# 安装依赖:pip install transformers datasets torch
from transformers import AutoModel, AutoTokenizer, TrainingArguments, Trainer
from datasets import load_dataset
import torch# 1. 加载数据与模型
dataset = load_dataset("imdb")
model_name = "bert-base-uncased"
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
tokenizer = AutoTokenizer.from_pretrained(model_name)# 2. 数据预处理
def tokenize(examples):return tokenizer(examples["text"], padding="max_length", truncation=True)
dataset = dataset.map(tokenize, batched=True)# 3. 设置训练参数
training_args = TrainingArguments(output_dir="./results",num_train_epochs=3,per_device_train_batch_size=8,evaluation_strategy="epoch"
)# 4. 实例化Trainer
trainer = Trainer(model=model,args=training_args,train_dataset=dataset["train"],eval_dataset=dataset["test"],compute_metrics=lambda p: {"accuracy": (p.predictions.argmax(-1) == p.label_ids).mean()}
)# 5. 训练与评估
trainer.train()
results = trainer.evaluate()
print(f"Final results: {results}")

6. 关键注意事项​​

  • ​​数据格式​​:确保数据集包含 input_ids, attention_mask 和 labels 字段
  • ​​标签处理​​:分类任务标签必须是 0 开始的连续整数
  • ​​设备兼容​​:Trainer 自动使用 GPU(若可用)
  • 模型保存​​:最佳模型保存在 output_dir 下的 best_model 目录
  • 日志查看​​:使用 tensorboard --logdir=./logs 查看训练曲线
    通过 Trainer,80% 的常见训练需求可通过配置解决,剩余 20% 可通过继承重写实现。官方文档提供了更多进阶用例:HuggingFace Trainer 文档
http://www.xdnf.cn/news/928243.html

相关文章:

  • Cloudflare 免费域名邮箱 支持 Catch-all 无限别名收件
  • JAVA理论第四战-线程池
  • 【AI论文】反思、重试、奖励:通过强化学习实现大型语言模型的自我提升
  • archlinux中使用 Emoji 字体
  • keil 5打开编译keil 4解决方案,兼容exe查找下载
  • 编程关键字
  • 【区块链基础】区块链的 Fork(分叉)深度解析:原理、类型、历史案例及共识机制的影响
  • 分类与扩展
  • 【推荐算法】推荐算法演进史:从协同过滤到深度强化学习
  • 「Java基本语法」代码格式与注释规范
  • 第二十七课:手搓梯度提升树
  • AI掘金时代:探讨如何用价值杠杆撬动付费用户增长
  • 记录下three.js学习过程中不理解问题①
  • 测试(面经 八股)
  • 《真假信号》速读笔记
  • Python爬虫实战:研究Unirest库相关技术
  • 王劲松《人民日报》撰文 重读抗战家书不忘来时路
  • Windows小说阅读软件推荐
  • Linux 文件系统核心:inode 与 block 深度解析(附实战案例与源码级原理)
  • 618来了,推荐京东云服务器
  • ROS C++ 实现消息通信与服务通信
  • 交叉熵损失函数和极大似然估计是什么,区别是什么
  • 关于队列的使用
  • 道路运输安全员考试分为哪些科目,各科目重点考察什么?
  • scratch农场小鸡 2024年全国青少年信息素养大赛 图形化编程 scratch变成挑战赛 复赛真题解析
  • string类型
  • Spring IoC 模块设计文档
  • libiec61850 mms协议异步模式
  • 如何构建船舵舵角和船的航向之间的动力学方程?它是一个一阶惯性环节吗?
  • 抖音怎么下载视频