LLM基础7_用于文本分类的微调
基于GitHub项目:https://github.com/datawhalechina/llms-from-scratch-cn
微调的概念
-
预训练:模型在大规模通用文本上学习语言模式(如GPT在互联网文本上训练)
-
微调:在预训练基础上,用特定领域数据继续训练模型
为什么需要微调?
-
领域适应:通用模型在专业领域表现不佳
-
任务定制:使模型适应特定任务(如分类、情感分析)
-
性能提升:微调后模型在特定任务上表现更好
-
数据效率:比从头训练节省90%以上的数据量
文本分类任务概览
文本分类是将文本分配到预定义类别的任务
常见应用场景
-
情感分析(正面/负面/中性)
-
主题分类(体育/政治/科技)
-
垃圾邮件检测
-
意图识别(客服场景)
-
新闻分类
微调流程详解
1. 准备领域数据
-
数据收集:获取与任务相关的文本
-
数据标注:人工或半自动标注类别
-
数据格式:
# CSV格式示例
text,label
"这个产品太好用了",positive
"服务太差,再也不买了",negative
"手机电池续航一般",neutral
2. 添加分类头
-
预训练模型:提供基础语言理解能力
-
分类头:添加在模型顶部的简单神经网络
from transformers import BertForSequenceClassification# 加载预训练模型
model = BertForSequenceClassification.from_pretrained("bert-base-chinese",num_labels=3 # 情感分类的类别数
)
3. 训练调整
- 冻结参数:只训练分类头(轻量微调)
- 全参数训练:更新所有参数(效果更好但资源消耗大)
from transformers import Trainer, TrainingArgumentstraining_args = TrainingArguments(output_dir='./results',num_train_epochs=3,per_device_train_batch_size=16,evaluation_strategy="epoch"
)trainer = Trainer(model=model,args=training_args,train_dataset=train_dataset,eval_dataset=eval_dataset
)trainer.train()
微调技术细节
学习率策略
-
预热学习率:开始小学习率,逐渐增大
-
衰减学习率:后期减小学习率
from transformers import get_linear_schedule_with_warmupoptimizer = AdamW(model.parameters(), lr=5e-5)
scheduler = get_linear_schedule_with_warmup(optimizer,num_warmup_steps=100,num_training_steps=1000
)
类别不平衡处理
当各类别样本数差异大时:
-
重采样:过采样少数类,欠采样多数类
-
类别权重:在损失函数中增加少数类权重
from sklearn.utils.class_weight import compute_class_weightclass_weights = compute_class_weight('balanced', classes, labels)
数据增强技巧
-
回译:中->英->中生成同义句
-
同义词替换:使用词嵌入替换同义词
-
随机插入/删除:增加文本多样性
评估与优化
指标 | 公式 | 适用场景 |
---|---|---|
准确率 | (TP+TN)/(TP+FP+FN+TN) | 类别平衡 |
F1值 | 2(PrecisionRecall)/(Precision+Recall) | 类别不平衡 |
AUC | ROC曲线下面积 | 整体性能 |
常见问题及解决方案
-
过拟合:
-
增加Dropout率
-
添加L2正则化
-
早停(Early Stopping)
-
-
欠拟合:
-
增加训练数据
-
减少正则化
-
延长训练时间
-
-
部署问题:
-
模型量化(减小模型大小)
-
ONNX格式转换(加速推理)
-
实战案例:新闻分类
数据集:THUCNews中文新闻数据集
-
10个类别:体育、财经、房产、教育等
-
每类6500条数据,共6.5万条
# 1. 加载预训练模型
from transformers import BertTokenizer, BertForSequenceClassificationtokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
model = BertForSequenceClassification.from_pretrained("bert-base-chinese",num_labels=10
)# 2. 准备数据
def preprocess_function(examples):return tokenizer(examples["text"],padding="max_length",truncation=True,max_length=128)from datasets import load_dataset
dataset = load_dataset("thucnews")
dataset = dataset.map(preprocess_function, batched=True)# 3. 训练配置
from transformers import TrainingArguments, Trainertraining_args = TrainingArguments(output_dir="./news_classifier",evaluation_strategy="epoch",learning_rate=2e-5,per_device_train_batch_size=32,num_train_epochs=3,weight_decay=0.01,
)trainer = Trainer(model=model,args=training_args,train_dataset=dataset["train"],eval_dataset=dataset["validation"],
)# 4. 开始训练
trainer.train()# 5. 评估
results = trainer.evaluate()
print(results)
高级技巧
少样本学习(Few-shot Learning)
当标注数据很少时:
1.提示工程(Prompt Engineering):
文本:"苹果发布新款iPhone"
提示:这是一条关于[科技]的新闻
2.模式利用训练(Pattern-Exploiting Training):
-
将分类任务转化为完形填空
-
"这条新闻的主题是____" → 模型预测[MASK]位置
知识蒸馏
-
教师模型:大型高精度模型
-
学生模型:小型高效模型
-
过程:学生模型学习教师模型的输出分布