快速上手Pytorch Lighting框架 | 深度学习入门
快速上手Pytorch Lighting框架 | 深度学习入门
- 前言
- 参考官方文档
- 介绍
- 快速上手
- 基本流程
- 常用接口
- LightningModule
- \_\_init\_\_ & setup()
- \*\_step()
- configure_callbacks()
- configure_optimizers()
- load_from_checkpoint
- Trainer
- 常用参数
- 可选接口
- Loggers
- TensorBoard Logger
- Callbacks
- EarlyStopping
- ModelCheckpoint
- ProgressBar
前言
本文将介绍一个深度学习的训练框架——Pytorch Lighting框架。首先会介绍Pytorch Lighting框架的特点,然后会聚焦于你使用该框架时一定会使用的那些接口,包括我个人学习该框架时的经验传授。
参考官方文档
- Welcome to ⚡ PyTorch Lightning — PyTorch Lightning 2.5.1.post0 documentation
- Lightning in 15 minutes — PyTorch Lightning 2.5.1.post0 documentation
- How to Organize PyTorch Into Lightning — PyTorch Lightning 2.5.1.post0 documentation
介绍
Pytorch Lightning是一个基于Pytorch的深度学习与机器学习的框架,它进一步封装Pytorch的接口,简化了深度学习训练代码的搭建过程,帮助用户能够关注于模型本身,而不需要再反复书写重复的训练代码。
Pytorch Lighting框架本质是对Pytorch的进一步封装,所以如果熟悉Pytorch框架,那么很容易上手Pytorch Lighting。结合官方文档以及个人使用体验,相比Pytorch,我认为Pytorch Lightning具有以下特点:
- 代码复用性:Pytorch Lightning提供训练流程的所有接口,可以通过继承的方式,准备训练不同阶段的组件,从而在相似任务之间使用同一份代码。
- 代码可读性:原本的Pytorch代码被进一步封装到框架中,让代码的聚合程度更高,训练流程更清晰,提高了代码的可读性。
- 灵活性:通过框架类方法,可以根据需求定制特定环节的计算逻辑,精细控制训练的每个细节。
- 可移植性:Pytorch Lightning的框架添加了自动检测训练设备的功能,同一份代码可以不仅在本地的CPU上训练,也可以通过远程服务器使用多GPU训练。
- 自动化:框架集成了一些训练会用到的工具,比如日志输出、检查点记录等等。
更多详细内容,可以查阅官方文档介绍!
快速上手
基本流程
使用Lighting框架训练一个深度学习模型,遵循以下的流程:
- 安装Pytorch Lighting
- 定义Pytorch Lighting模块
- 定义数据集(生成样本迭代器)
- 配置训练器,训练模型
- 使用模型:包括测试模型或使用模型预测…
- 可视化训练过程
常用接口
上一小节,简单介绍了使用Pytorch Lighting框架的流程。其本质和普通的机器学习训练流程是一致的,如果只是简单的使用PL框架,几乎可以不输入多余的参数,就能直接开始训练,PL会帮助你完成大量的任务。同时框架提供了训练流程中每一步的对应接口,让用户可以根据需求,修改不同的细节。本小节中将具体介绍这些重要的接口,主要对应上述流程的第2步、第4步及第5步。
对于第3步,PL训练时需要迭代器类型的输入,可以手动生成样本迭代器,也可以使用Pytorch中的Dataloader等,此处将不再展开。
LightningModule
LightningModule
是框架的核心部件,该类中提供关于训练的所有核心方法,涵盖6个方面:
- 模型初始化:init & setup()
- 训练循环:training_step()
- 验证循环:validation_step()
- 测试循环:test_step()
- 预测循环:predict_step()
- 优化器及学习率调整
__init__ & setup()
与类的基本使用方法相同,在LightingModule类的构造函数中,需要对类做必要的初始化,比如导入核心模型结构、优化器方法、损失函数类型等等。
setup(_trainer_, _pl_module_, _stage_)
setup()本质是一个回调函数,功能也是对类进行初始化设置,一般用于不同的训练阶段(predict,test,…)。调用该接口可以在不同的阶段采用不同的初始化策略。
*_step()
在不同的阶段的循环步中,可以部署期望的任务结果。除了基本的前馈计算、反向传播等操作,可以添加日志输出、指标收集等等。比如在train阶段,只获取loss指标;在test阶段,同时获取loss指标、acc指标等。
configure_callbacks()
通过重写该方法,可以定制训练所需的回调函数。当模型被调用的时候,比如执行test()的时候,框架会自动调用这些回调函数。
如果与Trainer中的回调函数表有冲突时,框架会优先使用此处的回调函数配置。
configure_optimizers()
该方法下,可以配置训练过程中使用的优化器类型以及具体的学习率。在常规模型的训练中,只会配置一个优化器,那么返回值就是单个优化器。如果是GANs或其他需要多个优化器的模型,支持返回多个迭代器,但是需要手动进行模型优化,即需要配置optimizer_step()
方法。
load_from_checkpoint
load_from_checkpoint(_checkpoint_path_, _map_location=None_, _hparams_file=None_, _**kwargs_)
一般在测试阶段会需要调用该函数,用一个已经训练好的模型来初始化LightingModule类。checkpoint_path是训练好的模型的.ckpt文件存储位置,PL框架也支持传入URL,或一个类。
TIPS: 如果构造函数传入超参数,记得在构造函数中调用调用self.save_hyperparameters()
。这样框架才会自动保存这些超参数到.ckpt文件中。否则如果训练、测试阶段分开进行时,需要重新导入模型,则需要准备.yaml文件,或超参数列表,才能正确的初始化模型。
Trainer
如果完成了LightningModule
的配置,直接实例化一个训练器Trainer,便可以直接开始训练,默认生成的Train可以自动的帮助你完成所有训练任务:
model = MyLightningModule()trainer = Trainer()
trainer.fit(model, train_dataloader, val_dataloader)
模型完成训练后,单独调用test()、validate()方法,对模型进行测试、验证。如果有特殊的训练、测试、验证需求,可以在实例化Trainer的时候进行配置。
常用参数
- accelerator & devices::
该参数是PL框架的特点之一,只需要实例化不同的Trainer就可以实现在不同硬件设备下的训练。也可以不指定参数,框架会自动匹配对应设备完成训练。
accelerator = ["cpu"] ["gpu"] ["tpu"] ["hpu"] ["auto"]
devices = [number of devices] ["auto"]
- callbacks:: 传入单个回调类或回调列表。当传入的是列表时,框架会自动根据顺序逐个调用回调类。如果在PL框架中重写了configure_callbacks()方法,则以框架中的回调类优先。
- max_epochs:: 最大的训练周期。
- enable_progress_bar:: 是否显示进度条,默认将会为True。
- logger:: 传入一个Loggers的实例,默认会使用TensorBoard Logger。设置为
False
则会禁用日志功能。 - log_every_n_steps:: 日志记录的步长
- strategy:: 训练策略,如ddp, fsdp等。
- limit_train_batches:: 限制训练时的batch数量,一般在调试时使用。传入一个数字,当数字小于1时按比例计算【0.25,则使用Dataloader总数的25%的batch】;当数字大于1时按个数计算【5,则使用5个batch】
可选接口
Loggers
在PL中,继承自基类Logger
有多种log格式可选,比如MLflow Logger,CSV logger,TensorBoard Logger等等。可以根据自己的需要,使用不同的日志记录形式。此处着重介绍TensorBoard Logger。
TensorBoard Logger
调用该类,日志将会以tensorboard格式进行记录,训练结束后可以可视化看到训练过程。
TensorBoardLogger(_save_dir_, _name='lightning_logs'_, _version=None_, _log_graph=False_, _default_hp_metric=True_, _prefix=''_, _sub_dir=None_, _**kwargs_)
重要的参数是save_dir
,name
,version
。因为这将决定日志的保存位置:save_dir/name/version。在不同的训练阶段可以实例化不同的logger,就可以将不同的阶段的日志放置在不同路径,方便分析研究。
构建好Logger的实例后,作为参数传入到Trainer
中即可,以下是官方文档中的例子:
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import TensorBoardLoggerlogger = TensorBoardLogger("tb_logs", name="my_model")
trainer = Trainer(logger=logger)
Callbacks
EarlyStopping
通过该类配置训练早停的策略。
EarlyStopping(_monitor_, _min_delta=0.0_, _patience=3_, _verbose=False_, _mode='min'_, _strict=True_, _check_finite=True_, _stopping_threshold=None_, _divergence_threshold=None_, _check_on_train_epoch_end=None_, _log_rank_zero_only=False_)
- monitor:: 监视指标。
- patience:: 传入一个整数n。默认情况下,每个epoch后都会检查指标的数值,当指标n次检查都一样时会触发早停。
- mode:: 可选max或min模式:max模式下,指标不再增长时会触发早停;min模式下,指标不再下降时会触发早停。
ModelCheckpoint
通过该类配置模型保存的保存策略。
ModelCheckpoint(_dirpath=None_, _filename=None_, _monitor=None_, _verbose=False_, _save_last=None_, _save_top_k=1_, _save_weights_only=False_, _mode='min'_, _auto_insert_metric_name=True_, _every_n_train_steps=None_, _train_time_interval=None_, _every_n_epochs=None_, _save_on_train_epoch_end=None_, _enable_version_counter=True_)
- dirpath & filename:: 模型文件将存储为dirpath/filename。
- monitor:: 评价指标,需要搭配save_top_k选项一起使用。
- save_top_k:: 传入一个整数n,指定保存模型的数量。
- n为0,不会保存模型。
- n为-1,会保存所有检查点时的模型。
- n大于2,模型会保存指标最好的n个模型。
ProgressBar
通过继承该类,重写成员方法,以按需求定制进度条的形式。
- get_metrics :: 可以从基类获得所有指标,然后返回想要显示的指标的字典
- print:: 定制进度条的输出样式。原文提到without breaking the progress bar.,应该是要注意输出的方式,比如不能重新刷新屏幕缓冲区?