PyTorch 学习率调度器(LR Scheduler)
文章目录
- PyTorch 学习率调度器(LR Scheduler)
- 1. 一句话定义
- 2. 通用使用套路
- 3. 内置调度器对比速览
- 4. 各调度器最小模板
- ① LambdaLR(线性 warmup)
- ② StepLR
- ③ MultiStepLR
- ④ CosineAnnealingLR
- ⑤ ReduceLROnPlateau(必须传指标)
- 5. 常用调试 API
- 6. 易踩坑 Top-3
- 7. 速记口诀
PyTorch 学习率调度器(LR Scheduler)
1. 一句话定义
每过一段时间 / 满足某条件,自动按规则修改优化器学习率的工具。
2. 通用使用套路
optimizer = torch.optim.Adam(model.parameters(), lr=初始LR)
scheduler = XXXLR(optimizer, ...) # 选下面任意一种
for epoch in range(EPOCH):train(...)val_loss = validate(...)optimizer.step() # ① 先更新参数scheduler.step(val_loss) # ② 再调度LR(ReduceLROnPlateau需传loss)
顺序:先 optimizer.step()
→ 再 scheduler.step()
,否则报警告。
3. 内置调度器对比速览
调度器 | 触发规则 | 主要参数 | 参数解释 | 典型场景 |
---|---|---|---|---|
LambdaLR | 自定义函数 f(epoch) 返回乘数 | lr_lambda , last_epoch | lr_lambda : 接收 epoch,返回 LR 乘数;last_epoch : 重启训练时设为上次 epoch | warmup、分段线性 |
StepLR | 固定每 step_size epoch 降一次 | step_size , gamma , last_epoch | step_size : 隔多少 epoch 降;gamma : 乘性衰减系数 | 常规“等间隔”下降 |
MultiStepLR | 指定里程碑 epoch 列表降 | milestones , gamma , last_epoch | milestones : List,到这些 epoch 就 ×gamma | 训练中期多段下降 |
CosineAnnealingLR | 余弦曲线从初始→η_min | T_max , eta_min , last_epoch | T_max : 半个余弦周期长度;eta_min : 最小 LR | 退火、cosine 重启 |
ReduceLROnPlateau | 监控指标停止改善时降 | mode , factor , patience , threshold , cooldown , min_lr | 见下方详注 | 验证 loss/acc 卡住时 |
ReduceLROnPlateau 参数详注
mode='min'
或'max'
:指标越小/越大越好factor=0.1
:新 LR = 旧 LR × factorpatience=3
:连续 3 次 epoch 无改善才降threshold=0.01
:改善幅度小于阈值视为无改善cooldown=1
:降 LR 后冻结监控的 epoch 数min_lr=1e-6
:下限,降到此值不再降
4. 各调度器最小模板
① LambdaLR(线性 warmup)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: epoch / 5 if epoch < 5 else 1)
② StepLR
scheduler = StepLR(optimizer, step_size=2, gamma=0.1) # 每 2 epoch ×0.1
③ MultiStepLR
scheduler = MultiStepLR(optimizer, milestones=[2, 6], gamma=0.1)
④ CosineAnnealingLR
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)
⑤ ReduceLROnPlateau(必须传指标)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3,threshold=0.01, cooldown=1, min_lr=1e-6)
val_loss = validate(...)
scheduler.step(val_loss) # ← 记得传指标
5. 常用调试 API
scheduler.get_last_lr() # 当前实际 LR 列表(每个 param_group)
scheduler.last_epoch # 已完成的 epoch 计数(从 0 开始)
6. 易踩坑 Top-3
- 先
optimizer.step()
再scheduler.step()
否则报警告 “Detected call oflr_scheduler.step()
beforeoptimizer.step()
”。 - ReduceLROnPlateau 必须传监控值
不传 → RuntimeError。 - Lambda/MultiStep 等无需监控值,传了 → TypeError。
7. 速记口诀
“优化先迈步,调度再跟进;Plateau 传 loss,其余不用问。”