当前位置: 首页 > news >正文

RTDETRv2 pytorch训练

RTDETRv2 pytorch训练

    • 1. 代码获取
    • 2. 数据集制作
    • 3. 环境配置
    • 4. 代码修改
        • 1)configs/dataset/coco_detection.yml
        • 2) configs/src/data/coco_dataset.py
        • 3)configs/src/core/yaml_utils.py
        • 4)configs/rtdeterv2/include/optimizer.yml
    • 5. 代码训练、验证、以及模型参数和FLOPs
        • 1) 训练
        • 2)验证

1. 代码获取

从github上下载官方源码官方源码,将其中的redetrv2_pytorch单独移动到桌面上

git clone https://github.com/lyuwenyu/RT-DETR.git

2. 数据集制作

使用github仓库中的 yolo2coco_1.py将YOLO标签转化为COCO数据集格式,然后将其转化为如下的存储顺序,最终移动到redetrv2_pytorch中

dataset└── coco├── train2017  ├── val2017 ├── test2017 └── annotations├── instance_train2017.json├── instance_val2017.json└── instance_test2017.json

3. 环境配置

windows 11 GPU4090
模型训练环境为YOLOv9,然后需要单独安装faster-coco-eval库

pip install faster-coco-eval

或者直接

pip install -r requirements.txt

4. 代码修改

1)configs/dataset/coco_detection.yml
num_classes: 2  # 原始 80  以二标签数据集为例(flame、smoke)# 如果按照上面的数据集制作数据集格式,则不需要修改内容,以个人数据集为例
img_folder: ./dataset/coco/train2017/
ann_file: ./dataset/coco/annotations/instances_train2017.json
2) configs/src/data/coco_dataset.py

mscoco_category2name修改为自己的标签内容

# 注意,0、1需要根据自己转化的COCO数据集确定的 我的是使用这个
mscoco_category2name = {0: 'flame',1: 'smoke'
}
# 按照80标签的COCO数据集是从1开始的
mscoco_category2name = {1: 'flame',2: 'smoke'
}
3)configs/src/core/yaml_utils.py

如果在训练的时候出现读取yaml文件,字体格式的报错,可将代码修改为如下。如果没有出现,可直接跳过本步骤。

# 添加encoding='utf-8'
with open(file_path, encoding='utf-8') as f:file_cfg = yaml.load(f, Loader=yaml.Loader)if file_cfg is None:return {}
4)configs/rtdeterv2/include/optimizer.yml

如果需要修改epoch,可以修改如下代码

epoches: 300

此时,对应的configs/rtdeterv2/include/dataloader.yml需要进行如下修改

train_dataloader: dataset: transforms:ops:- {type: RandomPhotometricDistort, p: 0.5}- {type: RandomZoomOut, fill: 0}- {type: RandomIoUCrop, p: 0.8}- {type: SanitizeBoundingBoxes, min_size: 1}- {type: RandomHorizontalFlip}- {type: Resize, size: [640, 640], }- {type: SanitizeBoundingBoxes, min_size: 1}- {type: ConvertPILImage, dtype: 'float32', scale: True}   - {type: ConvertBoxes, fmt: 'cxcywh', normalize: True}policy:name: stop_epochepoch: 299 # epoch in [71, ~) stop `ops`  原始为71ops: ['RandomPhotometricDistort', 'RandomZoomOut', 'RandomIoUCrop']collate_fn:type: BatchImageCollateFuncionscales: [480, 512, 544, 576, 608, 640, 640, 640, 672, 704, 736, 768, 800]stop_epoch: 299 # epoch in [71, ~) stop `multiscales` 原始为71

5. 代码训练、验证、以及模型参数和FLOPs

1) 训练
python tool/train.py -c ./configs/rtdetrv2/rtdetrv2_r50vd_m_7x_coco.yml --use-amp --seed=0
2)验证
python tool/train.py -c path/to/config -r path/to/checkpoint --test-only
http://www.xdnf.cn/news/207685.html

相关文章:

  • 【3D 地图】无人机测绘制作 3D 地图流程 ( 无人机采集数据 | 地图原始数据处理原理 | 数据处理软件 | 无人机测绘完整解决方案 )
  • 什么是静态住宅ip,跨境电商为什么要用静态住宅ip
  • IP属地是实时位置还是自己设置
  • SRIO IP调试问题记录(ready信号不拉高情况)
  • CentOS上搭建 Python 运行环境并使用第三方库
  • 【运维】还原 Docker 启动命令的利器:runlike 与 docker-autocompose
  • 数据结构---单链表的增删查改
  • Uniapp:设置页面下拉刷新
  • 1.1 点云数据获取方式——引言
  • Weka通过10天的内存指标数据计算内存指标动态阈值
  • 判断子序列
  • 问答:C++如何通过自定义实现移动构造函数和移动赋值运算符来实现rust的唯一所有权?
  • AI Agent开源技术栈
  • RabbitMQ 启动报错 “crypto.app“ 的解决方法
  • 项目三 - 任务2:创建笔记本电脑类(一爹多叔)
  • MySQL--数据引擎详解
  • gem5-gpu 安装过程碰到的问题记录 关于使用 Ruby + Garnet
  • Qt/C++开发监控GB28181系统/获取设备信息/设备配置参数/通道信息/设备状态
  • 当 AI 成为 “数字新物种”:人类职业的重构与进化
  • python:sklearn 决策树(Decision Tree)
  • 从 0 到 1:ComfyUI AI 工作流抠图构建全实践
  • Linux[配置vim]
  • 通信设备制造数字化转型中的创新模式与实践探索
  • 首页数据展示
  • 并发设计模式实战系列(9):消息传递(Message Passing)
  • Redis性能优化终极指南:从原理到实战的深度调优策略
  • 超越单体:进入微服务世界与Spring Cloud概述
  • Java Stream流
  • 【Fifty Project - D20】
  • 推荐系统实验指标置信度:p值核心原理与工程应用指南