Ultralytics中的YOLODataset和BaseDataset
YOLODataset 和 BaseDataset 是 Ultralytics YOLO 框架中用于加载和处理数据集的两个关键类。
YOLODataset类(ultralytics/data/dataset.py)继承于 BaseDataset类(ultralytics/data/base.py)
BaseDataset()
BaseDataset
是一个基础数据集类,提供了加载图像、缓存数据、预处理数据等核心功能。它是所有数据集类的父类,为子类提供了通用的数据加载和处理逻辑。
主要功能:
- 图像加载:从指定路径加载图像文件。
- 缓存机制:支持将图像缓存到内存或磁盘,以加速训练。
- 数据预处理:包括图像大小调整、填充等操作。
- 标签处理:加载和更新标签信息。
- 数据增强:通过
build_transforms
方法支持数据增强。
关键方法:
get_img_files
:从指定路径加载图像文件。load_image
:加载单个图像并返回其原始和调整后的尺寸。cache_images
:将图像缓存到内存或磁盘。set_rectangle
:设置矩形训练模式。get_image_and_label
:获取图像和标签信息。build_transforms
:构建数据增强和预处理管道(需子类实现)。
YOLODataset()
YOLODataset
继承自 BaseDataset
,并在此基础上扩展了 YOLO 特定任务的功能。它支持 YOLO 的目标检测、实例分割、姿态估计和旋转框(OBB)等任务。
主要功能:
- 标签加载:从磁盘或缓存中加载 YOLO 格式的标签。
- 任务支持:根据任务类型(检测、分割、姿态估计、OBB)加载相应的标签。
- 数据增强:扩展了 YOLO 特定的数据增强逻辑(如 Mosaic、MixUp 等)。
- 标签格式处理:将标签转换为 YOLO 训练所需的格式。
关键方法:
cache_labels
:缓存标签并检查图像和标签的完整性。get_labels
:加载标签并返回 YOLO 训练所需的格式。build_transforms
:构建 YOLO 特定的数据增强和预处理管道。update_labels_info
:更新标签格式以适应不同任务。collate_fn
:将数据样本整理为批次。
YOLODataset 和BaseDataset 的关系
YOLODataset
继承自 BaseDataset
,并在此基础上扩展了 YOLO 特定任务的功能。具体关系如下:
class YOLODataset(BaseDataset):def __init__(self, *args, data=None, task="detect", **kwargs):super().__init__(*args, channels=self.data["channels"], **kwargs)
YOLODataset
通过super().__init__()
调用BaseDataset
的初始化方法,继承了BaseDataset
的所有属性和方法。YOLODataset
在初始化时增加了task
参数,用于指定任务类型(检测、分割、姿态估计、OBB)。
方法重写
- 标签加载:
YOLODataset
重写了get_labels
方法,支持从 YOLO 格式的标签文件中加载数据。- 数据增强:
YOLODataset
重写了build_transforms
方法,增加了 YOLO 特定的数据增强逻辑(如 Mosaic、MixUp 等)。- 任务支持:
YOLODataset
根据任务类型加载相应的标签,并处理成 YOLO 训练所需的格式,update_labels_info。
YOLODataset整体流程
1. 初始化 (__init__
)
步骤:
1. 设置任务类型2. 调用父类初始化
- 通过
super().__init__(*args, channels=self.data["channels"], **kwargs)
调用父类BaseDataset
的初始化函数。- 父类会加载图像文件、标签文件,并进行缓存和预处理。
父类初始化流程:
1. 初始化参数。
2. 加载图像文件 (
get_img_files
),返回找到的图像文件路径列表。3. 加载标签文件 (
get_labels
),get_labels函数子类YOLODataset进行了复写!所以调用的是子类方法。4. 缓存图像 (
cache_images
),以加快训练速度。5. 构建数据增强方法 (
build_transforms
)。
2. 缓存标签 (cache_labels
)
cache_labels
方法用于缓存标签数据,并检查图像和标签的完整性。
3. 获取标签 (get_labels
)
get_labels
方法用于从缓存或磁盘加载标签数据,并准备用于训练。
下面是YOLODataset复写的函数。
def get_labels(self):"""Returns dictionary of labels for YOLO training.This method loads labels from disk or cache, verifies their integrity, and prepares them for training.Returns:(List[dict]): List of label dictionaries, each containing information about an image and its annotations."""self.label_files = img2label_paths(self.im_files)cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")try:cache, exists = load_dataset_cache_file(cache_path), True # attempt to load a *.cache fileassert cache["version"] == DATASET_CACHE_VERSION # matches current versionassert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hashexcept (FileNotFoundError, AssertionError, AttributeError):cache, exists = self.cache_labels(cache_path), False # run cache ops# Display cachenf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, totalif exists and LOCAL_RANK in {-1, 0}:d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"TQDM(None, desc=self.prefix + d, total=n, initial=n) # display resultsif cache["msgs"]:LOGGER.info("\n".join(cache["msgs"])) # display warnings# Read cache[cache.pop(k) for k in ("hash", "version", "msgs")] # remove itemslabels = cache["labels"]if not labels:LOGGER.warning(f"No images found in {cache_path}, training may not work correctly. {HELP_URL}")self.im_files = [lb["im_file"] for lb in labels] # update im_files# Check if the dataset is all boxes or all segmentslengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels)len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))if len_segments and len_boxes != len_segments:LOGGER.warning(f"Box and segment counts should be equal, but got len(segments) = {len_segments}, "f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. ""To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.")for lb in labels:lb["segments"] = []if len_cls == 0:LOGGER.warning(f"No labels found in {cache_path}, training may not work correctly. {HELP_URL}")return labels
方法流程:
1. 获取标签文件路径:调用img2label_paths
方法,根据图像文件路径生成对应的标签文件路径。self.label_files = img2label_paths(self.im_files)
2. 加载缓存文件:尝试从缓存文件加载标签数据。
cache, exists = load_dataset_cache_file(cache_path), True
3. 验证缓存文件:检查缓存文件的版本和哈希值是否匹配。
assert cache["version"] == DATASET_CACHE_VERSION assert cache["hash"] == get_hash(self.label_files + self.im_files)
4. 更新图像文件列表:从缓存中提取图像文件路径,并更新
self.im_files
。self.im_files = [lb["im_file"] for lb in labels]
5. 检查标签完整性:检查标签数据是否完整(如边界框和分割掩码的数量是否一致)。
if len_segments and len_boxes != len_segments:LOGGER.warning("Box and segment counts should be equal...")
6. 返回标签数据:返回处理后的标签数据,供训练使用。
return labels
get_labels
方法的主要功能包括:
- 加载标签文件:从磁盘或缓存中加载标签数据。
- 验证标签完整性:检查标签数据是否有效(如是否存在、是否损坏等)。
- 准备标签数据:将标签数据转换为模型训练所需的格式。
- 返回标签数据:返回处理后的标签数据,供训练使用。
get_labels
方法的返回值是一个包含标签数据的列表,列表中的每个元素是一个字典,表示一张图像的标签信息。每个字典的键值对如下:
键名 | 描述 |
---|---|
im_file | 图像文件的路径。 |
shape | 图像的原始尺寸(高度、宽度)。 |
cls | 目标的类别索引,形状为 (n, 1) ,其中 n 是目标的数量。 |
bboxes | 目标的边界框,形状为 (n, 4) ,格式为 [x_center, y_center, width, height] 。 |
segments | 目标的分割掩码(如果任务为分割),形状为 (n, k, 2) ,其中 k 是点的数量。 |
keypoints | 目标的关键点(如果任务为姿态估计),形状为 (n, k, 3) ,其中 k 是点的数量。 |
normalized | 布尔值,表示边界框和关键点是否已归一化。 |
bbox_format | 边界框的格式(如 "xywh" )。 |
输出示例:
[{"im_file": "path/to/images/image1.jpg","shape": (640, 640),"cls": [[0], [1]], # 两个目标,类别分别为 0 和 1"bboxes": [[0.5, 0.5, 0.2, 0.3], [0.7, 0.8, 0.1, 0.1]], # 两个目标的边界框"segments": [], # 分割掩码(如果任务为分割)"keypoints": [], # 关键点(如果任务为姿态估计)"normalized": True,"bbox_format": "xywh"},{"im_file": "path/to/images/image2.jpg","shape": (640, 640),"cls": [[0]], # 一个目标,类别为 0"bboxes": [[0.3, 0.4, 0.1, 0.2]], # 一个目标的边界框"segments": [], # 分割掩码(如果任务为分割)"keypoints": [], # 关键点(如果任务为姿态估计)"normalized": True,"bbox_format": "xywh"}
]
4. 构建数据增强 (
build_transforms
)
build_transforms
方法用于构建数据增强和预处理管道。
5. 更新标签信息 (update_labels_info
)
update_labels_info
方法用于更新标签数据的格式,以支持不同任务(如目标检测、分割等)。
6. 数据加载与训练
在训练过程中,YOLODataset
会通过 __getitem__
方法加载数据,并应用数据增强和预处理。
数据打包流程(collate_fn)
在 YOLODataset
中,数据打包的过程是通过 collate_fn
方法实现的。collate_fn
方法将多个样本(每张图像及其标签)打包成一个批次,以便输入到模型中进行训练。
collate_fn
方法的主要功能是将多个样本的数据(如图像、标签、边界框等)打包成一个批次。
1. 初始化批次字典:创建一个空字典
new_batch
,用于存储批次数据。new_batch = {}
2. 排序样本数据:确保每个样本的键值顺序一致,以便后续处理。
batch = [dict(sorted(b.items())) for b in batch]
3. 提取样本键值:获取所有样本的键(如
img
、bboxes
、cls
等),并将对应的值打包成列表。keys = batch[0].keys() values = list(zip(*[list(b.values()) for b in batch]))
4. 处理不同类型的数据:根据数据类型(如图像、边界框、类别等),使用不同的方式将数据打包成张量。
for i, k in enumerate(keys):value = values[i]if k in {"img", "text_feats"}:value = torch.stack(value, 0) # 图像数据直接堆叠elif k == "visuals":value = torch.nn.utils.rnn.pad_sequence(value, batch_first=True) # 填充序列elif k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:value = torch.cat(value, 0) # 标签数据拼接new_batch[k] = value
5. 处理批次索引:为每个样本添加批次索引,以便在训练过程中区分不同样本。
new_batch["batch_idx"] = list(new_batch["batch_idx"]) for i in range(len(new_batch["batch_idx"])):new_batch["batch_idx"][i] += i # 添加目标图像索引 new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
6. 返回批次数据:返回打包好的批次数据。
return new_batch
DataLoader中读取一批次数据的具体内容
在使用 DataLoader
加载数据时,collate_fn
方法会自动将多个样本打包成一个批次。以下是 DataLoader
中读取一批次数据的具体内容。
字段名 | 描述 |
---|---|
img | 图像数据,形状为 (batch_size, channels, height, width) 。 |
bboxes | 边界框数据,形状为 (total_objects, 4) ,格式为 [x_center, y_center, width, height] 。 |
cls | 类别数据,形状为 (total_objects, 1) 。 |
segments | 分割掩码数据,形状为 (total_objects, k, 2) ,其中 k 是点的数量。 |
keypoints | 关键点数据,形状为 (total_objects, k, 3) ,其中 k 是点的数量。 |
batch_idx | 批次索引,形状为 (total_objects,) ,用于区分不同样本的目标。 |
从 DataLoader 中读取一批次数据的示例输出:
{"img": tensor([[[[0.5, 0.5, 0.2, 0.3], [0.7, 0.8, 0.1, 0.1]], # 图像数据[[0.3, 0.4, 0.1, 0.2], [0.6, 0.7, 0.2, 0.1]]]),"bboxes": tensor([[0.5, 0.5, 0.2, 0.3], [0.7, 0.8, 0.1, 0.1], # 边界框数据[0.3, 0.4, 0.1, 0.2], [0.6, 0.7, 0.2, 0.1]]),"cls": tensor([[0], [1], [0], [1]]), # 类别数据"segments": tensor([], dtype=torch.float32), # 分割掩码数据"keypoints": tensor([], dtype=torch.float32), # 关键点数据"batch_idx": tensor([0, 0, 1, 1]) # 批次索引
}