当前位置: 首页 > news >正文

基于 Transformer robert的情感分类任务实践总结之三——FGM

FGM(Fast Gradient Method)对抗训练详解

FGM(Fast Gradient Method,快速梯度方法)是一种简单而有效的对抗训练方法,旨在提高深度学习模型在面对对抗样本时的鲁棒性。对抗样本是经过微小扰动(肉眼难以察觉)但能导致模型错误分类的输入样本。

为什么需要对抗训练?

深度学习模型在许多任务上都取得了SOTA性能,但它们往往对输入中的微小、有意的扰动非常敏感。这些扰动可以导致模型输出错误的预测。对抗训练通过在训练过程中引入这些对抗扰动,使模型学习如何更好地处理它们,从而提高模型的鲁棒性和泛化能力。


FGM 的核心思想

本质思想:在训练阶段,主动给模型输入加一点"微小扰动"(adversarial perturbation),逼迫模型在 “坏情况” 下也能做出正确预测。
FGM 的核心思想是在模型训练过程中,通过计算损失函数对输入Embedding的梯度,生成一个微小的扰动,并将其添加到原始输入 Embedding 上,形成一个对抗样本。然后,模型会同时学习如何正确分类原始样本和这些对抗样本。

具体来说,FGM 认为,模型之所以容易被对抗样本愚弄,是因为模型在某些方向上的梯度过大,导致微小的输入变化就能引起输出的剧烈变化。FGM 的目标是平滑这些敏感方向,让模型对这些扰动不那么敏感。


FGM 的执行流程

FGM 的实现非常巧妙,它在 PyTorch 的 backwardoptimizer.step 之间插入了一个额外的“攻击-反向传播-恢复”步骤。

以下是带有 FGM 的训练流程:

  1. 正常前向传播 (forward):
    模型接收原始输入 x x x,进行前向传播,得到预测输出 y ^ \hat{y} y^
  2. 计算损失 (compute_loss):
    根据 y ^ \hat{y} y^ 和真实标签 y y y 计算损失 L L L(例如,交叉熵损失)。
  3. 第一次反向传播 (backward):
    对损失 L L L 进行反向传播,计算模型参数的梯度。请注意,此时我们只计算了梯度,但还没有更新模型参数。
  4. FGM 攻击 (FGM.attack()):
    这是 FGM 的关键步骤。它利用第一步反向传播得到的Embedding 梯度信息,计算出一个微小的扰动 r r r
    r = ϵ ⋅ ∇ x L ∥ ∇ x L ∥ r = \epsilon \cdot \frac{\nabla_x L}{\|\nabla_x L\|} r=ϵxLxL
    其中, ∇ x L \nabla_x L xL 是损失函数 L L L 对输入 Embedding x x x 的梯度, ϵ \epsilon ϵ 是一个预设的扰动步长(FGM_EPSILON)。
    然后,这个扰动 r r r直接添加到原始 Embedding 上。这一步通过修改模型中 Embedding 层的参数值来模拟对输入 Embedding 的扰动。
  5. 对抗样本前向传播 (forward_adv):
    模型在带有扰动的 Embedding 的基础上再次进行前向传播,得到对抗样本的预测输出 y ^ a d v \hat{y}_{adv} y^adv
  6. 计算对抗损失 (compute_loss_adv):
    根据 y ^ a d v \hat{y}_{adv} y^adv 和真实标签 y y y 计算对抗损失 L a d v L_{adv} Ladv
  7. 第二次反向传播 (backward_adv):
    对对抗损失 L a d v L_{adv} Ladv 进行反向传播,计算模型参数的梯度。这些梯度会与第一次反向传播累积的梯度叠加。这意味着,模型在一次迭代中,同时考虑了原始样本的损失梯度和对抗样本的损失梯度。
  8. FGM 恢复 (FGM.restore()):
    在计算完对抗损失并反向传播之后,FGM 会将 Embedding 参数恢复到 FGM 攻击之前的状态。这是因为 FGM 只是利用了对抗扰动来计算额外的梯度,而不是永久性地改变 Embedding 参数。
  9. 优化器更新 (optimizer.step):
    优化器根据累积的梯度(来自原始损失和对抗损失)更新模型的参数。
  10. 学习率调度器更新 (update model):
    学习率调度器更新学习率。

FGM 的优势与局限性

优势:

  • 简单有效: FGM 算法原理直观,实现相对简单,且在许多任务上都能有效提升模型鲁棒性。
  • 计算效率高: 相较于更复杂的对抗训练方法(如 PGD),FGM 只需要进行一次额外的反向传播,计算开销相对较小。
  • 兼容性好: FGM 可以很容易地集成到现有的训练框架中,如 Hugging Face Transformers。

局限性:

  • 单步攻击: FGM 是一种单步攻击方法,可能无法生成像 PGD(Projected Gradient Descent)等多步攻击方法那样强大的对抗样本,因此模型对更复杂的攻击可能仍然脆弱。
  • 无法保证绝对鲁棒性: 对抗训练旨在提高鲁棒性,但并不能保证模型在所有可能的对抗攻击下都完全安全。

FGMCallback 在 Trainer 中的作用

在 Hugging Face Trainer 中集成 FGM,通常会通过自定义 Trainer 或使用 TrainerCallback 来实现。你提供的代码中,FGMCallback 正是扮演了这个角色。

FGMCallback 的核心在于 on_after_backward 方法。这个方法在每次批量训练的梯度计算完毕后(即 loss.backward() 之后,optimizer.step() 之前)被 Trainer 调用。这正是 FGM 插入其逻辑的理想时机:

  • self.fgm.attack(): 在这里执行 FGM 的攻击步骤,向 Embedding 添加扰动。
  • adv_outputs1 = model(...)adv_outputs2 = model(...): 模型使用带有扰动的 Embedding 进行前向传播,这里因为结合了 R-Drop,所以进行了两次前向传播以计算 R-Drop 的 KL 散度损失。
  • adv_loss = self.rdrop_loss_fn(...): 计算基于对抗样本的损失(这里是 R-Drop 损失)。
  • adv_loss.backward(): 对对抗损失进行反向传播,将额外的梯度累积到模型参数上。
  • self.fgm.restore(): 恢复 Embedding 参数到原始状态,避免永久修改。

通过这种方式,FGMCallback 确保了在每次模型参数更新之前,都会额外计算一次对抗损失的梯度,从而迫使模型学习如何抵抗微小的对抗扰动。


总结

FGM 是一种实用且高效的对抗训练技术,通过在训练过程中生成并利用对抗扰动来增强模型的鲁棒性。它通过在模型参数更新之前插入额外的“攻击-反向传播-恢复”步骤,有效地将对抗样本的梯度信息融入到模型的学习过程中,从而使模型在面对恶意干扰时表现更稳定。

代码:

# Advanced RoBERTa Sentiment Classifier with R-Drop + FGM + LabelSmoothing + CosineAnnealing"""
执行流程:正常流程:
forward → compute_loss → backward → optimizer.step → update model加入 FGM 后流程:
forward → compute_loss → backward → FGMCallback.on_after_backward:→ FGM.attack() → forward_adv → compute_loss_adv → backward_adv → FGM.restore() → optimizer.step → update model
"""import os
import numpy as np
import torch
import torch.nn as nn
from transformers import (AutoTokenizer,AutoModelForSequenceClassification,Trainer,TrainingArguments,DataCollatorWithPadding,set_seed,EarlyStoppingCallback,TrainerCallback
)
from datasets import load_dataset
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score# 固定随机种子
set_seed(42)# 配置参数
MODEL_NAME = "roberta-base"
NUM_LABELS = 2
R_DROP_ALPHA = 5.0
LABEL_SMOOTHING = 0.1
FGM_EPSILON = 1.0# 加载数据
dataset = load_dataset("imdb")
train_dataset = dataset["train"]
test_dataset = dataset["test"]# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)def preprocess_function(examples):return tokenizer(examples["text"], truncation=True)train_dataset = train_dataset.map(preprocess_function, batched=True)
test_dataset = test_dataset.map(preprocess_function, batched=True)# 数据整理器
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)# 加载模型
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=NUM_LABELS
)# --- R-Drop Loss ---
class RDropLoss(nn.Module):def __init__(self, alpha=1.0, label_smoothing=0.0):super().__init__()self.alpha = alphaself.ce = nn.CrossEntropyLoss(label_smoothing=label_smoothing)self.kl = nn.KLDivLoss(reduction="batchmean")def forward(self, logits1, logits2, labels):# CE loss(两次 forward)ce_loss1 = self.ce(logits1, labels)ce_loss2 = self.ce(logits2, labels)ce_loss = 0.5 * (ce_loss1 + ce_loss2)# KL divergence lossp = torch.log_softmax(logits1, dim=-1)q = torch.log_softmax(logits2, dim=-1)p_softmax = torch.softmax(logits1, dim=-1)q_softmax = torch.softmax(logits2, dim=-1)kl_loss = 0.5 * (self.kl(p, q_softmax) + self.kl(q, p_softmax))# 总 loss = CE + alpha * KLreturn ce_loss + self.alpha * kl_loss# --- FGM 对抗训练模块 ---
class FGM:def __init__(self, model, epsilon=1.0):self.model = modelself.epsilon = epsilonself.backup = {}def attack(self, emb_name='embeddings.word_embeddings'):"""执行 attack:- 在 embedding 层添加 adversarial 噪声- 不修改模型的结构,只是修改参数值(加上扰动)"""for name, param in self.model.named_parameters():if param.requires_grad and emb_name in name and param.grad is not None:# 保存原始参数self.backup[name] = param.data.clone()# 计算扰动方向norm = torch.norm(param.grad)if norm != 0:r_at = self.epsilon * param.grad / norm# 添加扰动param.data.add_(r_at)def restore(self, emb_name='embeddings.word_embeddings'):"""恢复 embedding 参数到原始状态(去除 adversarial 扰动)"""for name, param in self.model.named_parameters():if param.requires_grad and emb_name in name and name in self.backup:param.data = self.backup[name]self.backup = {}# --- FGM Callback ---
class FGMCallback(TrainerCallback):def __init__(self, fgm, rdrop_loss_fn):self.fgm = fgmself.rdrop_loss_fn = rdrop_loss_fndef on_after_backward(self, args, state, control, model=None, inputs=None, optimizer=None, **kwargs):"""FGM 主逻辑在这里做:1 attack → 修改 embedding 参数2 forward_adv → 再次前向传播3 compute adversarial loss4 backward_adv → 对 adversarial loss 做反向传播5 restore → 恢复 embedding 参数"""# 1 Attackself.fgm.attack()# 2 Forward again (adv)adv_outputs1 = model(**{k: v for k, v in inputs.items() if k != "labels"})adv_outputs2 = model(**{k: v for k, v in inputs.items() if k != "labels"})adv_logits1 = adv_outputs1.logitsadv_logits2 = adv_outputs2.logitslabels = inputs["labels"]# 3 Compute adversarial lossadv_loss = self.rdrop_loss_fn(adv_logits1, adv_logits2, labels)# 4 Backward on adversarial lossaccelerator = kwargs.get("accelerator", None)if accelerator is not None:accelerator.backward(adv_loss)else:adv_loss.backward()# 5 Restore model paramsself.fgm.restore()# 后续 optimizer.step() / scheduler.step() 由 Trainer 自动完成# --- 评价指标 ---
def compute_metrics(eval_pred):logits, labels = eval_predprobs = torch.softmax(torch.tensor(logits), dim=-1).numpy()predictions = np.argmax(logits, axis=-1)acc = accuracy_score(labels, predictions)f1 = f1_score(labels, predictions)try:auc = roc_auc_score(labels, probs[:, 1])except:auc = 0.0return {"accuracy": acc, "f1": f1, "auc": auc}# --- 自定义 Trainer ---
class AdvancedTrainer(Trainer):def __init__(self, *args, alpha=1.0, label_smoothing=0.0, **kwargs):super().__init__(*args, **kwargs)self.rdrop_loss_fn = RDropLoss(alpha=alpha, label_smoothing=label_smoothing)def compute_loss(self, model, inputs, return_outputs=False,**kwargs):labels = inputs["labels"]
# 如果传 labels,outputs1.loss 是被 自动计算出来的 CE loss,你就无法只拿 logits 计算 R-Drop KL Loss。# R-Drop forwardoutputs1 = model(**{k: v for k, v in inputs.items() if k != "labels"})outputs2 = model(**{k: v for k, v in inputs.items() if k != "labels"})logits1 = outputs1.logitslogits2 = outputs2.logits# Compute R-Drop lossloss = self.rdrop_loss_fn(logits1, logits2, labels)return (loss, outputs1) if return_outputs else loss# --- Trainer 参数 ---
training_args = TrainingArguments(output_dir="./results_adv_rdrop",eval_strategy="epoch",save_strategy="epoch",learning_rate=2e-5,per_device_train_batch_size=16,per_device_eval_batch_size=16,num_train_epochs=5,weight_decay=0.01,warmup_ratio=0.1,lr_scheduler_type="cosine",logging_dir="./logs_adv_rdrop",logging_steps=50,load_best_model_at_end=True,metric_for_best_model="f1",fp16=True,save_total_limit=2,
)# 初始化 FGM
fgm = FGM(model, epsilon=FGM_EPSILON)# 初始化 EarlyStopping
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.01)# 初始化 Trainer
trainer = AdvancedTrainer(model=model,args=training_args,train_dataset=train_dataset,eval_dataset=test_dataset,processing_class=tokenizer,data_collator=data_collator,compute_metrics=compute_metrics,alpha=R_DROP_ALPHA,label_smoothing=LABEL_SMOOTHING,callbacks=[FGMCallback(fgm=fgm, rdrop_loss_fn=RDropLoss(alpha=R_DROP_ALPHA, label_smoothing=LABEL_SMOOTHING)),early_stopping_callback],
)# --- 开始训练 ---
trainer.train()# --- 评估 ---
trainer.evaluate()

##tensorboard

在这里插入图片描述

http://www.xdnf.cn/news/932653.html

相关文章:

  • 从代码学习深度强化学习 - 多臂老虎机 PyTorch版
  • 【深度学习|学习笔记】自监督学习(Self-Supervised Learning, SSL)在遥感领域中的典型应用案例及其在小样本学习中的作用,附代码。
  • LeetCode --- 452周赛
  • 高保真组件库:按钮
  • GitHub 趋势日报 (2025年06月07日)
  • Langgraph实战-自省式RAG: Self-RAG
  • 材料力学速通
  • 北京工作周期7,8,9,10
  • 【react实战】如何实现监听窗口大小变化
  • 2025HNCTF - Crypto
  • webstorm 配置Eslint
  • Springboot 基于MessageSource配置国际化
  • C#调用Rust动态链接库DLL的案例
  • ​RBAC(基于角色的访问控制)权限管理详解
  • 学习日记-day24-6.8
  • 鸿蒙API自翻译
  • 70常用控件_QVBoxLayout的使用
  • 指针的使用——字符、字符串、字符串数组(char*)
  • C++进阶--C++11--智能指针(重点)
  • 12.7Swing控件6 JList
  • gitcode与github加速计划
  • LabVIEW Modbus 主站冗余控制
  • css | class中 ‘.‘ 和 ‘:‘ 的使用 | 如,何时用 .is-selected{ ... } 何时用 :hover{...}?
  • 3Ds Max 2026安装包+教程网盘下载与安装教程指南
  • [特殊字符] Whisper 模型介绍(OpenAI 语音识别系统)
  • WEB3全栈开发——面试专业技能点P1Node.js / Web3.js / Ethers.js
  • 【RockeMQ】第2节|RocketMQ快速实战以及核⼼概念详解(二)
  • 图神经网络(GNN)模型的基本原理
  • MySQL:CTE 通用表达式
  • 在React 中安装和配置 shadcn/ui