当前位置: 首页 > news >正文

Ultralytics中的YOLODataset和BaseDataset

YOLODataset 和 BaseDataset  是 Ultralytics YOLO 框架中用于加载和处理数据集的两个关键类。

YOLODataset类(ultralytics/data/dataset.py)继承于 BaseDataset类(ultralytics/data/base.py)

 BaseDataset()

BaseDataset 是一个基础数据集类,提供了加载图像、缓存数据、预处理数据等核心功能。它是所有数据集类的父类,为子类提供了通用的数据加载和处理逻辑。

主要功能:

  • 图像加载:从指定路径加载图像文件。
  • 缓存机制:支持将图像缓存到内存或磁盘,以加速训练。
  • 数据预处理:包括图像大小调整、填充等操作。
  • 标签处理:加载和更新标签信息。
  • 数据增强:通过 build_transforms 方法支持数据增强。

关键方法:

  • get_img_files:从指定路径加载图像文件。
  • load_image:加载单个图像并返回其原始和调整后的尺寸。
  • cache_images:将图像缓存到内存或磁盘。
  • set_rectangle:设置矩形训练模式。
  • get_image_and_label:获取图像和标签信息。
  • build_transforms:构建数据增强和预处理管道(需子类实现)。

YOLODataset()

YOLODataset 继承自 BaseDataset,并在此基础上扩展了 YOLO 特定任务的功能。它支持 YOLO 的目标检测、实例分割、姿态估计和旋转框(OBB)等任务。

主要功能:

  • 标签加载:从磁盘或缓存中加载 YOLO 格式的标签。
  • 任务支持:根据任务类型(检测、分割、姿态估计、OBB)加载相应的标签。
  • 数据增强:扩展了 YOLO 特定的数据增强逻辑(如 Mosaic、MixUp 等)。
  • 标签格式处理:将标签转换为 YOLO 训练所需的格式。

关键方法:

  • cache_labels:缓存标签并检查图像和标签的完整性。
  • get_labels:加载标签并返回 YOLO 训练所需的格式。
  • build_transforms:构建 YOLO 特定的数据增强和预处理管道。
  • update_labels_info:更新标签格式以适应不同任务。
  • collate_fn:将数据样本整理为批次。

YOLODataset 和BaseDataset 的关系

YOLODataset 继承自 BaseDataset,并在此基础上扩展了 YOLO 特定任务的功能。具体关系如下:

class YOLODataset(BaseDataset):def __init__(self, *args, data=None, task="detect", **kwargs):super().__init__(*args, channels=self.data["channels"], **kwargs)
  • YOLODataset 通过 super().__init__() 调用 BaseDataset 的初始化方法,继承了 BaseDataset 的所有属性和方法。
  • YOLODataset 在初始化时增加了 task 参数,用于指定任务类型(检测、分割、姿态估计、OBB)。

方法重写

  • 标签加载YOLODataset 重写了 get_labels 方法,支持从 YOLO 格式的标签文件中加载数据。
  • 数据增强YOLODataset 重写了 build_transforms 方法,增加了 YOLO 特定的数据增强逻辑(如 Mosaic、MixUp 等)。
  • 任务支持YOLODataset 根据任务类型加载相应的标签,并处理成 YOLO 训练所需的格式,update_labels_info。

YOLODataset整体流程

1. 初始化 (__init__)

步骤:
1. 设置任务类型

2. 调用父类初始化

  • 通过 super().__init__(*args, channels=self.data["channels"], **kwargs) 调用父类 BaseDataset 的初始化函数。
  • 父类会加载图像文件、标签文件,并进行缓存和预处理。

 父类初始化流程:

1. 初始化参数

2. 加载图像文件 (get_img_files),返回找到的图像文件路径列表。

3. 加载标签文件 (get_labels),get_labels函数子类YOLODataset进行了复写!所以调用的是子类方法。 

4. 缓存图像 (cache_images),以加快训练速度。

5. 构建数据增强方法 (build_transforms)

 2. 缓存标签 (cache_labels)

cache_labels 方法用于缓存标签数据,并检查图像和标签的完整性。

 3. 获取标签 (get_labels)

get_labels 方法用于从缓存或磁盘加载标签数据,并准备用于训练。

下面是YOLODataset复写的函数。

    def get_labels(self):"""Returns dictionary of labels for YOLO training.This method loads labels from disk or cache, verifies their integrity, and prepares them for training.Returns:(List[dict]): List of label dictionaries, each containing information about an image and its annotations."""self.label_files = img2label_paths(self.im_files)cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")try:cache, exists = load_dataset_cache_file(cache_path), True  # attempt to load a *.cache fileassert cache["version"] == DATASET_CACHE_VERSION  # matches current versionassert cache["hash"] == get_hash(self.label_files + self.im_files)  # identical hashexcept (FileNotFoundError, AssertionError, AttributeError):cache, exists = self.cache_labels(cache_path), False  # run cache ops# Display cachenf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, totalif exists and LOCAL_RANK in {-1, 0}:d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"TQDM(None, desc=self.prefix + d, total=n, initial=n)  # display resultsif cache["msgs"]:LOGGER.info("\n".join(cache["msgs"]))  # display warnings# Read cache[cache.pop(k) for k in ("hash", "version", "msgs")]  # remove itemslabels = cache["labels"]if not labels:LOGGER.warning(f"No images found in {cache_path}, training may not work correctly. {HELP_URL}")self.im_files = [lb["im_file"] for lb in labels]  # update im_files# Check if the dataset is all boxes or all segmentslengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels)len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))if len_segments and len_boxes != len_segments:LOGGER.warning(f"Box and segment counts should be equal, but got len(segments) = {len_segments}, "f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. ""To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.")for lb in labels:lb["segments"] = []if len_cls == 0:LOGGER.warning(f"No labels found in {cache_path}, training may not work correctly. {HELP_URL}")return labels

  方法流程:
1. 获取标签文件路径:调用 img2label_paths 方法,根据图像文件路径生成对应的标签文件路径。

self.label_files = img2label_paths(self.im_files)

2. 加载缓存文件:尝试从缓存文件加载标签数据。

cache, exists = load_dataset_cache_file(cache_path), True

3. 验证缓存文件:检查缓存文件的版本和哈希值是否匹配。

assert cache["version"] == DATASET_CACHE_VERSION
assert cache["hash"] == get_hash(self.label_files + self.im_files)

4. 更新图像文件列表:从缓存中提取图像文件路径,并更新 self.im_files

self.im_files = [lb["im_file"] for lb in labels]

5. 检查标签完整性:检查标签数据是否完整(如边界框和分割掩码的数量是否一致)。

if len_segments and len_boxes != len_segments:LOGGER.warning("Box and segment counts should be equal...")

6. 返回标签数据:返回处理后的标签数据,供训练使用。

return labels

get_labels 方法的主要功能包括:

  1. 加载标签文件:从磁盘或缓存中加载标签数据。
  2. 验证标签完整性:检查标签数据是否有效(如是否存在、是否损坏等)。
  3. 准备标签数据:将标签数据转换为模型训练所需的格式。
  4. 返回标签数据:返回处理后的标签数据,供训练使用。

get_labels 方法的返回值是一个包含标签数据的列表,列表中的每个元素是一个字典,表示一张图像的标签信息。每个字典的键值对如下:

键名描述
im_file图像文件的路径。
shape图像的原始尺寸(高度、宽度)。
cls目标的类别索引,形状为 (n, 1),其中 n 是目标的数量。
bboxes目标的边界框,形状为 (n, 4),格式为 [x_center, y_center, width, height]
segments目标的分割掩码(如果任务为分割),形状为 (n, k, 2),其中 k 是点的数量。
keypoints目标的关键点(如果任务为姿态估计),形状为 (n, k, 3),其中 k 是点的数量。
normalized布尔值,表示边界框和关键点是否已归一化。
bbox_format边界框的格式(如 "xywh")。
输出示例:
[{"im_file": "path/to/images/image1.jpg","shape": (640, 640),"cls": [[0], [1]],  # 两个目标,类别分别为 0 和 1"bboxes": [[0.5, 0.5, 0.2, 0.3], [0.7, 0.8, 0.1, 0.1]],  # 两个目标的边界框"segments": [],  # 分割掩码(如果任务为分割)"keypoints": [],  # 关键点(如果任务为姿态估计)"normalized": True,"bbox_format": "xywh"},{"im_file": "path/to/images/image2.jpg","shape": (640, 640),"cls": [[0]],  # 一个目标,类别为 0"bboxes": [[0.3, 0.4, 0.1, 0.2]],  # 一个目标的边界框"segments": [],  # 分割掩码(如果任务为分割)"keypoints": [],  # 关键点(如果任务为姿态估计)"normalized": True,"bbox_format": "xywh"}
]

4. 构建数据增强 (build_transforms)

build_transforms 方法用于构建数据增强和预处理管道。

5. 更新标签信息 (update_labels_info)

update_labels_info 方法用于更新标签数据的格式,以支持不同任务(如目标检测、分割等)。

 6. 数据加载与训练

在训练过程中,YOLODataset 会通过 __getitem__ 方法加载数据,并应用数据增强和预处理。


数据打包流程(collate_fn)

YOLODataset 中,数据打包的过程是通过 collate_fn 方法实现的。collate_fn 方法将多个样本(每张图像及其标签)打包成一个批次,以便输入到模型中进行训练。

collate_fn 方法的主要功能是将多个样本的数据(如图像、标签、边界框等)打包成一个批次。

1. 初始化批次字典:创建一个空字典 new_batch,用于存储批次数据。

new_batch = {}

2. 排序样本数据:确保每个样本的键值顺序一致,以便后续处理。

batch = [dict(sorted(b.items())) for b in batch]

3. 提取样本键值:获取所有样本的键(如 imgbboxescls 等),并将对应的值打包成列表。

keys = batch[0].keys()
values = list(zip(*[list(b.values()) for b in batch]))

4. 处理不同类型的数据:根据数据类型(如图像、边界框、类别等),使用不同的方式将数据打包成张量。

for i, k in enumerate(keys):value = values[i]if k in {"img", "text_feats"}:value = torch.stack(value, 0)  # 图像数据直接堆叠elif k == "visuals":value = torch.nn.utils.rnn.pad_sequence(value, batch_first=True)  # 填充序列elif k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:value = torch.cat(value, 0)  # 标签数据拼接new_batch[k] = value

5. 处理批次索引:为每个样本添加批次索引,以便在训练过程中区分不同样本。

new_batch["batch_idx"] = list(new_batch["batch_idx"])
for i in range(len(new_batch["batch_idx"])):new_batch["batch_idx"][i] += i  # 添加目标图像索引
new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)

6. 返回批次数据:返回打包好的批次数据。

return new_batch

DataLoader中读取一批次数据的具体内容

在使用 DataLoader 加载数据时,collate_fn 方法会自动将多个样本打包成一个批次。以下是 DataLoader 中读取一批次数据的具体内容。

字段名描述
img图像数据,形状为 (batch_size, channels, height, width)
bboxes边界框数据,形状为 (total_objects, 4),格式为 [x_center, y_center, width, height]
cls类别数据,形状为 (total_objects, 1)
segments分割掩码数据,形状为 (total_objects, k, 2),其中 k 是点的数量。
keypoints关键点数据,形状为 (total_objects, k, 3),其中 k 是点的数量。
batch_idx批次索引,形状为 (total_objects,),用于区分不同样本的目标。
从 DataLoader 中读取一批次数据的示例输出:
{"img": tensor([[[[0.5, 0.5, 0.2, 0.3], [0.7, 0.8, 0.1, 0.1]],  # 图像数据[[0.3, 0.4, 0.1, 0.2], [0.6, 0.7, 0.2, 0.1]]]),"bboxes": tensor([[0.5, 0.5, 0.2, 0.3], [0.7, 0.8, 0.1, 0.1],  # 边界框数据[0.3, 0.4, 0.1, 0.2], [0.6, 0.7, 0.2, 0.1]]),"cls": tensor([[0], [1], [0], [1]]),  # 类别数据"segments": tensor([], dtype=torch.float32),  # 分割掩码数据"keypoints": tensor([], dtype=torch.float32),  # 关键点数据"batch_idx": tensor([0, 0, 1, 1])  # 批次索引
}

 

http://www.xdnf.cn/news/331615.html

相关文章:

  • comfyui 实现中文提示词翻译英文进行图像生成
  • 低成本监控IPC模组概述
  • D盘出现不知名文件
  • int (*)[3]和int (*arr_ptr)[3]区别
  • Spark应用部署模式实例
  • 个人网站versionI正式上线了!Personal Website for Jing Liu
  • ✍️【TS类型体操进阶】挑战类型极限,成为类型魔法师!♂️✨
  • JAVA八股文
  • CI/CD与DevOps流程流程简述(提供思路)
  • 使用pdm管理python项目时去哪里找nuitka
  • 如何通过复盘提升团队能力?
  • 数组和集合
  • 【C++的类型转换】
  • 【漏洞预警】:致远OA V8.1 SP2 data.htm DOM型XSS漏洞
  • 使用 `detach()` 断开与共享特征层的连接
  • (已完结)完美解决C盘拓展卷是灰色的无法扩容的问题以及如何正确地在WINDOS上从一个盘扩容到C盘
  • Android 如何理解 Java JNI 中的引用与 Java 对象应用的区别
  • java算法的核心思想及考察的解题思路
  • Codeforces Round 1022 (Div. 2)
  • YOLOv1:开创实时目标检测新纪元
  • go.mod没有自动缓存问题
  • vue截图-html2canvas
  • 《硬件视界》专栏介绍(持续更新ing)
  • Qt学习Day2:信号槽
  • 从SQL的执行流程彻底详解预编译是如何解决SQL注入问题
  • Linux57配置MYSQL YUM源
  • 离散化(竞赛)
  • MinIo安装和使用操作说明(windows)
  • C++相关学习过程
  • 《USB技术应用与开发》第七讲:CDC串口设备案例