【SAM2代码解析】数据集处理3--混合数据加载器(DataLoader)
前情提要—trainer
展示了在训练过程中,数据是如何流动的
1)trainer的初始化
trainer = instantiate(cfg.trainer, _recursive_=False)
-
传入的参数:
- data:training.dataset.sam2_datasets.TorchTrainMixedDataset
- model: training.model.sam2.SAM2Train
- checkpoint:…
- mode: train_only
- optim: torch.optim.AdamW
- loss: training.loss_fns.MultiStepMultiMasksAndIous
- data:training.dataset.sam2_datasets.TorchTrainMixedDataset
-
配置信息赋值给实例变量
-
初始化其他配置信息…
-
self._setup_dataloaders() 设置数据加载器
- 根据mode的格式实例化train_dataset,如下图所示(
接7.1-1)的初始化模块
):
- 根据mode的格式实例化train_dataset,如下图所示(
2)trainer.run
dataloader = self.train_dataset.get_loader(epoch=int(self.epoch))-----接7.1-2)的get_loader方法
7. sam2_dataset.py
7.1 MixedDataLoader 类
1)初始化
- 传入参数
这里的传入参数就是前面trainer初始化中,data的配置信息。 - 初始化信息
- 属性赋值
- 设置数据集的周期(??没看懂这样的意义)
- sam允许训练时使用多个数据集
- 计算每个数据集的采样概率,概率为子数据集量/全部数据集量,若只有一个数据集,那么采样概率为1
2)get_loader
def get_loader(self, epoch) -> Iterable:# 初始化数据加载器列表dataloaders = []# 遍历数据集和批次大小for d_idx, (dataset, batch_size) in enumerate(zip(self.datasets, self.batch_sizes)):# 处理每个数据集# 如果每个周期的阶段数 self.phases_per_epoch 大于 1,则处理数据集的分块和设置周期。if self.phases_per_epoch > 1:# Major epoch that looops over entire dataset# len(main_epoch) == phases_per_epoch * len(epoch)# 计算主周期和局部阶段main_epoch = epoch // self.phases_per_epoch# Phase with in the main epochlocal_phase = epoch % self.phases_per_epoch# Start of new data-epoch or job is resumed after preemtion.if local_phase == 0 or self.chunks[d_idx] is None:# set seed for dataset epoch# If using RepeatFactorWrapper, this step currectly re-samples indices before chunking.self._set_dataset_epoch(dataset, main_epoch)# Separate random generator for subset samplingg = torch.Generator()g.manual_seed(main_epoch)self.chunks[d_idx] = torch.chunk(torch.randperm(len(dataset), generator=g),self.phases_per_epoch,)dataset = Subset(dataset, self.chunks[d_idx][local_phase])# 如果是新的数据周期或工作在中断后恢复,则设置数据集的周期,并为数据集创建随机分块。else:self._set_dataset_epoch(dataset, epoch)# 创建DistributedSampler采样器,用于在分布式环境中对数据集进行采样sampler = DistributedSampler(dataset, shuffle=self.shuffle)sampler.set_epoch(epoch)# 创建 BatchSampler 对象,用于从采样器中按批次大小采样数据。batch_sampler = BatchSampler(sampler, batch_size, drop_last=self.drop_last)# 创建数据加载器dataloaders.append(DataLoader(dataset,num_workers=self.num_workers,pin_memory=self.pin_memory,batch_sampler=batch_sampler,collate_fn=self.collate_fn,worker_init_fn=self.worker_init_fn,))# 返回混合数据加载器return MixedDataLoader(dataloaders, self.dataset_prob)