Ultralytics代码详细解析(三:engine->trainer.py主框架)
目录
- 引言
- 一、框架
- 二、详解
- 1. 初始化
- 2. 训练
- 3. 训练设置
- 4. 循环训练
- 5. 验证保存
- 6. 回调函数系统
- 7. 可视化
- 参考链接
- 感谢
引言
之后就正式要进入正文了,这一篇先说engine文件夹吧,毕竟它是启动训练的核心,那么我们就来学习模型训练trainer.py的主框架吧~~~
一、框架
二、详解
1. 初始化
- –init–:主初始化入口;
- check_resume:检查恢复训练点;
- get_dataset:加载和验证数据集;
- setup_model:模型初始化;
- get_model:获取模型(由子类实现);
2. 训练
- train:训练主入口;
- _do_train:实际训练过程;
- _setup_train:设置训练所需的各种组件:模型、冻结层、AMP、优化器等;
- final_eval:最终评估,使用最好的模型进行验证并生成结果。
3. 训练设置
- build_optimizer:构建优化器;
- _setup_scheduler:设置学习率调度器;
- get_dataloader:数据加载;
- build_dataset:构建数据集(由子类实现);
- resume_training:恢复训练;
- _close_dataloader_mosaic:关闭数据加载器的mosaic增强;
- set_model_attributes:设置模型属性;
- auto_batch:自动计算合适的batch大小;
- _setup_ddp:设置分布式训练环境。
4. 循环训练
- _model_train:将模型设置为训练模式,并冻结指定BN层的统计信息;
- preprocess_batch:数据预处理;
-
- build_targets:构建目标(由子类实现);
- optimizer_step:参数更新,执行优化步骤;
- _get_memory:获取当前设备的内存使用情况;
- _clear_memory:清理内存;
- progress_string:返回训练进度字符串(由子类实现)。
5. 验证保存
- validate:在验证集上验证模型性能;
- get_validator:获取验证器实例;
- label_loss_items:返回带标签的损失项,用于记录和打印;
- save_metrics:将训练指标保存到CSV文件;
- read_results_csv:读取训练过程中保存的CSV文件;
- save_model:保存模型,包括last.pt, best.pt等。
6. 回调函数系统
- add_callback:添加特定事件回调函数;
- set_callback:设置特定事件回调函数;
- run_callbacks:运行特定事件回调函数。
7. 可视化
- plot_training_samples:绘制训练样本(由子类实现);
- plot_training_labels:绘制训练标签(由子类实现);
- plot_metrics:绘制训练指标图表;
- on_plot:注册绘图,用于回调函数。
参考链接
无
感谢
深知不易
感谢自己
!!!