当前位置: 首页 > news >正文

PyTorch Lightning(训练评估框架)

PyTorch Lightning

  • 1、教程
  • 2、TensorBoardLogger和SummaryWriter
    • 1. **SummaryWriter**
    • 2. **TensorBoardLogger**
    • 3. 区别对比表
  • 3、 汇总
    • 1. PyTorch Lightning + TensorBoard + ModelCheckpoint + EarlyStopping
      • 核心代码示例:
    • 2. TensorFlow / Keras + TensorBoard + ModelCheckpoint + EarlyStopping
    • 3. Stable Baselines3 (强化学习)
    • 4. Huggingface Trainer(NLP)
    • 5. 结合Weights & Biases(W\&B)
    • 总结推荐

1、教程

官方:
https://lightning.ai/docs/pytorch/stable/
https://lightning.ai/docs/overview/getting-started

PyTorch Lightning:
在 GPU、TPU 等设备上对 AI 模型进行微调和预训练。专注于科学,而非工程。

其他:
https://blog.51cto.com/u_16175490/13322417
https://github.com/3017218062/Pytorch-Lightning-Learning/tree/master
https://evernorif.github.io/2024/01/19/Pytorch-Lightning%E5%BF%AB%E9%80%9F%E5%85%A5%E9%97%A8/

训练:

  • 超参数,
    https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html

  • 可视化,

  • https://lightning.ai/docs/pytorch/stable/visualize/logging_basic.html
    https://lightning.ai/docs/pytorch/stable/visualize/logging_intermediate.html
    https://lightning.ai/docs/pytorch/stable/levels/intermediate_level_10.html
    Tensorboard集成:https://lightning.ai/docs/pytorch/stable/visualize/logging_intermediate.html

  • 回调,
    https://lightning.ai/docs/pytorch/stable/levels/advanced_level_16.html

评估:
https://lightning.ai/docs/pytorch/stable/common/evaluation_basic.html
https://lightning.ai/docs/torchmetrics/stable/

数据:
https://lightning.ai/docs/pytorch/stable/levels/intermediate_level_9.html

模型:
https://lightning.ai/docs/pytorch/stable/levels/advanced_level_17.html

注意:PyTorch Lightning:不追求训练速度,建议用tensorboard替代它,PyTorch Lightning坑有点多(相对复杂)

2、TensorBoardLogger和SummaryWriter

SummaryWriter和TensorBoardLogger

参考:SummaryWriter
https://pytorch.ac.cn/tutorials/intermediate/tensorboard_tutorial.html
https://pytorch.ac.cn/tutorials/beginner/introyt/tensorboardyt_tutorial.html
https://docs.pytorch.ac.cn/docs/stable/tensorboard.html

SummaryWriterTensorBoardLogger 是两个在 PyTorch(尤其是 PyTorch Lightning)中常用的 TensorBoard 日志工具,但它们的定位和使用场景略有不同。


1. SummaryWriter

来源:torch.utils.tensorboard.SummaryWriter(PyTorch 原生提供)
作用:直接向 TensorBoard 写日志,属于底层 API。

特点

  • PyTorch 官方原生工具,不依赖 Lightning。
  • 需要你手动调用 add_scalar()add_histogram()add_graph() 等方法来记录数据。
  • 灵活性很高,但需要自己管理日志目录、日志周期等。
  • 适合 自己实现训练循环 时使用。

常用写法

from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter(log_dir="runs/exp1")for step in range(100):loss = 0.1 * stepwriter.add_scalar("Loss/train", loss, step)writer.add_scalar("Accuracy/train", step * 0.01, step)writer.close()

运行:

tensorboard --logdir=runs

2. TensorBoardLogger

来源:pytorch_lightning.loggers.TensorBoardLogger(Lightning 提供)
作用:Lightning 封装的 TensorBoard 日志工具,是高层 API。

特点

  • 专门给 PyTorch Lightning 用的,能自动把训练中的 lossmetrics 等记录到 TensorBoard。
  • 不用手动写 add_scalar(),Lightning 在 trainer.fit() 时会自动调用。
  • 适合 用 Lightning 管理训练流程 的场景。
  • 可以结合 Lightning 的 log() 方法方便地记录。

常用写法

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLoggerlogger = TensorBoardLogger("lightning_logs", name="my_model")trainer = Trainer(max_epochs=10,logger=logger
)trainer.fit(model)

在 Lightning 的 LightningModule 中:

def training_step(self, batch, batch_idx):loss = self(batch)self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)return loss

3. 区别对比表

特性SummaryWriterTensorBoardLogger
所属PyTorch 原生PyTorch Lightning
灵活性高(完全手动记录)中(自动记录 + Lightning API)
是否自动记录训练指标❌ 需要手动调用✅ 自动记录
是否依赖 Lightning
使用场景自己写训练循环用 Lightning 管理训练
API 例子writer.add_scalar()self.log() / logger

总结

  • 如果你用原生 PyTorch 训练 → 用 SummaryWriter,自己控制日志写法。
  • 如果你用 PyTorch Lightning → 用 TensorBoardLogger,自动记录更方便。
  • 其实 TensorBoardLogger 底层就是用 SummaryWriter 实现的,只是 Lightning 帮你做了封装和调用。

如果你愿意,我可以帮你画一个 SummaryWriter vs TensorBoardLogger 的调用流程图,这样能更直观看清它们的关系。你需要吗?

3、 汇总

详细参考下面笔记:
https://blog.csdn.net/weixin_44986037/article/details/150425539?spm=1001.2014.3001.5502
https://blog.csdn.net/weixin_44986037/article/details/151153156?spm=1001.2014.3001.5502
https://blog.csdn.net/weixin_44986037/article/details/151225587?spm=1001.2014.3001.5502
https://blog.csdn.net/weixin_44986037/article/details/150427299?spm=1001.2014.3001.5502
https://blog.csdn.net/weixin_44986037/article/details/148544185?spm=1001.2014.3001.5501

原文链接:https://blog.csdn.net/weixin_44986037/article/details/150425539
建议用针对性框架:
Stable Baselines3 : 强化学习
Huggingface Trainer :NLP及Transformer模型
PyTorch Lightning:不追求训练速度,建议用tensorboard替代它,PyTorch Lightning坑有点多

提问:现成的ai模型训练评估框架,包括tersorboard查看训练数据、筛选模型、评估效果

提问:

简单概述主流现成的AI模型训练+评估框架,包含:训练过程中的数据监控(比如通过 TensorBoard 可视化损失、指标)模型筛选(如自动保存最优模型、早停)评估效果(自动计算指标,方便比较模型)

你想要的是一个现成的AI模型训练+评估框架,包含:

  • 训练过程中的数据监控(比如通过 TensorBoard 可视化损失、指标)
  • 模型筛选(如自动保存最优模型、早停)
  • 评估效果(自动计算指标,方便比较模型)

下面给你推荐几个主流且功能完整的框架/工具组合,覆盖你需求的这几个点:


1. PyTorch Lightning + TensorBoard + ModelCheckpoint + EarlyStopping

  • 训练监控:Lightning内置对TensorBoard的支持,训练时自动记录loss、metric到TensorBoard。
  • 模型筛选:内置 ModelCheckpoint,可自动保存指标最优模型(如val_loss最小、准确率最高)。
  • 早停:内置 EarlyStopping,当验证指标不再提升时自动停止训练。
  • 评估效果:支持训练、验证、测试阶段指标计算和记录。

核心代码示例:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStoppingcheckpoint_callback = ModelCheckpoint(monitor='val_loss',    # 监控验证损失mode='min',           # 取最小值对应模型save_top_k=1,verbose=True,
)early_stop_callback = EarlyStopping(monitor='val_loss',patience=5,verbose=True,mode='min'
)trainer = pl.Trainer(max_epochs=50,callbacks=[checkpoint_callback, early_stop_callback],logger=pl.loggers.TensorBoardLogger('tb_logs/')
)trainer.fit(model, train_dataloader, val_dataloader)

启动后,在终端运行 tensorboard --logdir=tb_logs/ 即可实时查看训练过程。


2. TensorFlow / Keras + TensorBoard + ModelCheckpoint + EarlyStopping

  • TensorFlow内置支持TensorBoard,训练时会自动写入日志。
  • ModelCheckpoint可保存验证指标最优模型。
  • EarlyStopping自动停止训练避免过拟合。

示例:

callbacks = [tf.keras.callbacks.TensorBoard(log_dir='./logs'),tf.keras.callbacks.ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True),tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)
]model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=50, callbacks=callbacks)

3. Stable Baselines3 (强化学习)

  • 自带 EvalCallback 用于训练过程中自动评估模型表现。
  • CheckpointCallback 用于保存模型。
  • 训练日志可通过 TensorBoard 监控。

示例:

from stable_baselines3.common.callbacks import EvalCallback, CheckpointCallbackeval_callback = EvalCallback(eval_env, best_model_save_path='./logs/',log_path='./logs/', eval_freq=10000, deterministic=True)
checkpoint_callback = CheckpointCallback(save_freq=5000, save_path='./logs/')model.learn(total_timesteps=100000, callback=[eval_callback, checkpoint_callback])

运行 tensorboard --logdir=./logs 可查看指标变化。


4. Huggingface Trainer(NLP)

  • 内置日志记录(支持TensorBoard、WandB等)
  • 自动保存最优模型和早停(EarlyStoppingCallback)
  • 评估指标自动计算和记录

示例:

from transformers import Trainer, TrainingArguments, EarlyStoppingCallbacktraining_args = TrainingArguments(output_dir='./results',evaluation_strategy='epoch',save_strategy='epoch',logging_dir='./logs',logging_steps=10,load_best_model_at_end=True,
)trainer = Trainer(model=model,args=training_args,train_dataset=train_dataset,eval_dataset=eval_dataset,callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)trainer.train()

然后使用TensorBoard查看 ./logs 文件夹。


5. 结合Weights & Biases(W&B)

(即WandB)

  • 支持各种框架(PyTorch、TensorFlow、Lightning、SB3、HF等)
  • 自动记录训练日志、模型、超参
  • 支持在线实时可视化和对比多个实验
  • 支持模型版本管理和筛选

官网:https://wandb.ai/site


总结推荐

框架/工具训练监控(TensorBoard)自动保存最优模型早停评估指标简单易用适用范围
PyTorch LightningPyTorch通用
TensorFlow / KerasTensorFlow/Keras
Stable Baselines3部分强化学习
Huggingface TrainerNLP及Transformer模型
Weights & Biases ((WandB))各类深度学习框架

如果你告诉我你用的是哪种框架,我可以帮你写具体的训练+评估+TensorBoard配置示例,或者帮你搭建完整的训练流水线。你看?

http://www.xdnf.cn/news/1486783.html

相关文章:

  • Python+DRVT 从外部调用 Revit:批量创建楼板
  • 基于SpringBoot+Vue的健身房管理系统的设计与实现(代码+数据库+LW)
  • 多环境配置切换机制能否让开发与生产无缝衔接?
  • 【论文阅读】自我进化的AI智能体综述
  • Unity学习----【进阶】Input System学习(一)--导入与基础的设备调用API
  • 《探索C++11:现代语法的内存管理优化“性能指针”(下篇)》
  • LeetCode 面试经典 150 题:移除元素(双指针思想优化解法详解)
  • RICOH理光 Priport DX4443c速印机 印A3的问题
  • 数据结构之二叉树(2)
  • 如何解决pip安装报错ModuleNotFoundError: No module named ‘setuptools’问题
  • 嵌入式学习---(ARM)
  • AutoHotkey将脚本编译为exe文件
  • DevOps实战(3) - 使用Arbess+GitLab+Hadess实现Java项目自动化部署
  • 【Java基础|第三十五篇】类加载与反射
  • 如何在Python中使用正则表达式?
  • 基于Apache Flink Stateful Functions的事件驱动微服务架构设计与实践指南
  • Flink TaskManager日志时间与实际时间有偏差
  • 鱼眼相机模型
  • JVM-默背版
  • 实验室服务器配置|通过Docker实现Linux系统多用户隔离与安全防控
  • Flink NetworkBufferPool核心原理解析
  • Android --- SystemUI 导入Android Studio及debug
  • 2025年体制内职业发展相关认证选择指南
  • 超越自动补全:将AI编码助手深度集成到你的开发工作流​​
  • 微信小程序中实现AI对话、生成3D图像并使用xr-frame演示
  • C++ 连接 Redis:redis-plus-plus 安装与使用入门指南
  • 关于npm的钩子函数
  • 【iOS】push,pop和present,dismiss
  • 上架商品合规流程有多条,有的长,有的短,有的需要审核,校验商品的合规性
  • RestTemplate使用 | RestTemplate设置http连接池参数