PyTorch Lightning(训练评估框架)
PyTorch Lightning
- 1、教程
- 2、TensorBoardLogger和SummaryWriter
- 1. **SummaryWriter**
- 2. **TensorBoardLogger**
- 3. 区别对比表
- 3、 汇总
- 1. PyTorch Lightning + TensorBoard + ModelCheckpoint + EarlyStopping
- 核心代码示例:
- 2. TensorFlow / Keras + TensorBoard + ModelCheckpoint + EarlyStopping
- 3. Stable Baselines3 (强化学习)
- 4. Huggingface Trainer(NLP)
- 5. 结合Weights & Biases(W\&B)
- 总结推荐
1、教程
官方:
https://lightning.ai/docs/pytorch/stable/
https://lightning.ai/docs/overview/getting-started
PyTorch Lightning:
在 GPU、TPU 等设备上对 AI 模型进行微调和预训练。专注于科学,而非工程。
其他:
https://blog.51cto.com/u_16175490/13322417
https://github.com/3017218062/Pytorch-Lightning-Learning/tree/master
https://evernorif.github.io/2024/01/19/Pytorch-Lightning%E5%BF%AB%E9%80%9F%E5%85%A5%E9%97%A8/
训练:
-
超参数,
https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html -
可视化,
-
https://lightning.ai/docs/pytorch/stable/visualize/logging_basic.html
https://lightning.ai/docs/pytorch/stable/visualize/logging_intermediate.html
https://lightning.ai/docs/pytorch/stable/levels/intermediate_level_10.html
Tensorboard集成:https://lightning.ai/docs/pytorch/stable/visualize/logging_intermediate.html -
回调,
https://lightning.ai/docs/pytorch/stable/levels/advanced_level_16.html
评估:
https://lightning.ai/docs/pytorch/stable/common/evaluation_basic.html
https://lightning.ai/docs/torchmetrics/stable/
数据:
https://lightning.ai/docs/pytorch/stable/levels/intermediate_level_9.html
模型:
https://lightning.ai/docs/pytorch/stable/levels/advanced_level_17.html
注意:PyTorch Lightning:不追求训练速度,建议用tensorboard替代它,PyTorch Lightning坑有点多(相对复杂)
2、TensorBoardLogger和SummaryWriter
SummaryWriter和TensorBoardLogger
参考:SummaryWriter
https://pytorch.ac.cn/tutorials/intermediate/tensorboard_tutorial.html
https://pytorch.ac.cn/tutorials/beginner/introyt/tensorboardyt_tutorial.html
https://docs.pytorch.ac.cn/docs/stable/tensorboard.html
SummaryWriter
和 TensorBoardLogger
是两个在 PyTorch(尤其是 PyTorch Lightning)中常用的 TensorBoard 日志工具,但它们的定位和使用场景略有不同。
1. SummaryWriter
来源:torch.utils.tensorboard.SummaryWriter
(PyTorch 原生提供)
作用:直接向 TensorBoard 写日志,属于底层 API。
特点:
- 是 PyTorch 官方原生工具,不依赖 Lightning。
- 需要你手动调用
add_scalar()
、add_histogram()
、add_graph()
等方法来记录数据。 - 灵活性很高,但需要自己管理日志目录、日志周期等。
- 适合 自己实现训练循环 时使用。
常用写法:
from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter(log_dir="runs/exp1")for step in range(100):loss = 0.1 * stepwriter.add_scalar("Loss/train", loss, step)writer.add_scalar("Accuracy/train", step * 0.01, step)writer.close()
运行:
tensorboard --logdir=runs
2. TensorBoardLogger
来源:pytorch_lightning.loggers.TensorBoardLogger
(Lightning 提供)
作用:Lightning 封装的 TensorBoard 日志工具,是高层 API。
特点:
- 专门给 PyTorch Lightning 用的,能自动把训练中的
loss
、metrics
等记录到 TensorBoard。 - 不用手动写
add_scalar()
,Lightning 在trainer.fit()
时会自动调用。 - 适合 用 Lightning 管理训练流程 的场景。
- 可以结合 Lightning 的
log()
方法方便地记录。
常用写法:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLoggerlogger = TensorBoardLogger("lightning_logs", name="my_model")trainer = Trainer(max_epochs=10,logger=logger
)trainer.fit(model)
在 Lightning 的 LightningModule
中:
def training_step(self, batch, batch_idx):loss = self(batch)self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)return loss
3. 区别对比表
特性 | SummaryWriter | TensorBoardLogger |
---|---|---|
所属 | PyTorch 原生 | PyTorch Lightning |
灵活性 | 高(完全手动记录) | 中(自动记录 + Lightning API) |
是否自动记录训练指标 | ❌ 需要手动调用 | ✅ 自动记录 |
是否依赖 Lightning | ❌ | ✅ |
使用场景 | 自己写训练循环 | 用 Lightning 管理训练 |
API 例子 | writer.add_scalar() | self.log() / logger |
✅ 总结:
- 如果你用原生 PyTorch 训练 → 用
SummaryWriter
,自己控制日志写法。 - 如果你用 PyTorch Lightning → 用
TensorBoardLogger
,自动记录更方便。 - 其实
TensorBoardLogger
底层就是用SummaryWriter
实现的,只是 Lightning 帮你做了封装和调用。
如果你愿意,我可以帮你画一个 SummaryWriter vs TensorBoardLogger 的调用流程图,这样能更直观看清它们的关系。你需要吗?
3、 汇总
详细参考下面笔记:
https://blog.csdn.net/weixin_44986037/article/details/150425539?spm=1001.2014.3001.5502
https://blog.csdn.net/weixin_44986037/article/details/151153156?spm=1001.2014.3001.5502
https://blog.csdn.net/weixin_44986037/article/details/151225587?spm=1001.2014.3001.5502
https://blog.csdn.net/weixin_44986037/article/details/150427299?spm=1001.2014.3001.5502
https://blog.csdn.net/weixin_44986037/article/details/148544185?spm=1001.2014.3001.5501
原文链接:https://blog.csdn.net/weixin_44986037/article/details/150425539
建议用针对性框架:
Stable Baselines3 : 强化学习
Huggingface Trainer :NLP及Transformer模型
PyTorch Lightning:不追求训练速度,建议用tensorboard替代它,PyTorch Lightning坑有点多
提问:现成的ai模型训练评估框架,包括tersorboard查看训练数据、筛选模型、评估效果
提问:
简单概述主流现成的AI模型训练+评估框架,包含:训练过程中的数据监控(比如通过 TensorBoard 可视化损失、指标)模型筛选(如自动保存最优模型、早停)评估效果(自动计算指标,方便比较模型)
你想要的是一个现成的AI模型训练+评估框架,包含:
- 训练过程中的数据监控(比如通过 TensorBoard 可视化损失、指标)
- 模型筛选(如自动保存最优模型、早停)
- 评估效果(自动计算指标,方便比较模型)
下面给你推荐几个主流且功能完整的框架/工具组合,覆盖你需求的这几个点:
1. PyTorch Lightning + TensorBoard + ModelCheckpoint + EarlyStopping
- 训练监控:Lightning内置对TensorBoard的支持,训练时自动记录loss、metric到TensorBoard。
- 模型筛选:内置
ModelCheckpoint
,可自动保存指标最优模型(如val_loss最小、准确率最高)。 - 早停:内置
EarlyStopping
,当验证指标不再提升时自动停止训练。 - 评估效果:支持训练、验证、测试阶段指标计算和记录。
核心代码示例:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStoppingcheckpoint_callback = ModelCheckpoint(monitor='val_loss', # 监控验证损失mode='min', # 取最小值对应模型save_top_k=1,verbose=True,
)early_stop_callback = EarlyStopping(monitor='val_loss',patience=5,verbose=True,mode='min'
)trainer = pl.Trainer(max_epochs=50,callbacks=[checkpoint_callback, early_stop_callback],logger=pl.loggers.TensorBoardLogger('tb_logs/')
)trainer.fit(model, train_dataloader, val_dataloader)
启动后,在终端运行 tensorboard --logdir=tb_logs/
即可实时查看训练过程。
2. TensorFlow / Keras + TensorBoard + ModelCheckpoint + EarlyStopping
- TensorFlow内置支持TensorBoard,训练时会自动写入日志。
- ModelCheckpoint可保存验证指标最优模型。
- EarlyStopping自动停止训练避免过拟合。
示例:
callbacks = [tf.keras.callbacks.TensorBoard(log_dir='./logs'),tf.keras.callbacks.ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True),tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)
]model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=50, callbacks=callbacks)
3. Stable Baselines3 (强化学习)
- 自带
EvalCallback
用于训练过程中自动评估模型表现。 CheckpointCallback
用于保存模型。- 训练日志可通过 TensorBoard 监控。
示例:
from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallbackeval_callback = EvalCallback(eval_env, best_model_save_path='./logs/',log_path='./logs/', eval_freq=10000, deterministic=True)
checkpoint_callback = CheckpointCallback(save_freq=5000, save_path='./logs/')model.learn(total_timesteps=100000, callback=[eval_callback, checkpoint_callback])
运行 tensorboard --logdir=./logs
可查看指标变化。
4. Huggingface Trainer(NLP)
- 内置日志记录(支持TensorBoard、WandB等)
- 自动保存最优模型和早停(EarlyStoppingCallback)
- 评估指标自动计算和记录
示例:
from transformers import Trainer, TrainingArguments, EarlyStoppingCallbacktraining_args = TrainingArguments(output_dir='./results',evaluation_strategy='epoch',save_strategy='epoch',logging_dir='./logs',logging_steps=10,load_best_model_at_end=True,
)trainer = Trainer(model=model,args=training_args,train_dataset=train_dataset,eval_dataset=eval_dataset,callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)trainer.train()
然后使用TensorBoard查看 ./logs
文件夹。
5. 结合Weights & Biases(W&B)
(即WandB)
- 支持各种框架(PyTorch、TensorFlow、Lightning、SB3、HF等)
- 自动记录训练日志、模型、超参
- 支持在线实时可视化和对比多个实验
- 支持模型版本管理和筛选
官网:https://wandb.ai/site
总结推荐
框架/工具 | 训练监控(TensorBoard) | 自动保存最优模型 | 早停 | 评估指标 | 简单易用 | 适用范围 |
---|---|---|---|---|---|---|
PyTorch Lightning | ✅ | ✅ | ✅ | ✅ | ✅ | PyTorch通用 |
TensorFlow / Keras | ✅ | ✅ | ✅ | ✅ | ✅ | TensorFlow/Keras |
Stable Baselines3 | ✅ | ✅ | 部分 | ✅ | ✅ | 强化学习 |
Huggingface Trainer | ✅ | ✅ | ✅ | ✅ | ✅ | NLP及Transformer模型 |
Weights & Biases ((WandB)) | ✅ | ✅ | ✅ | ✅ | ✅ | 各类深度学习框架 |
如果你告诉我你用的是哪种框架,我可以帮你写具体的训练+评估+TensorBoard配置示例,或者帮你搭建完整的训练流水线。你看?