当前位置: 首页 > news >正文

【SAM2代码解析】数据集处理3--混合数据加载器(DataLoader)

在这里插入图片描述

前情提要—trainer

展示了在训练过程中,数据是如何流动的

1)trainer的初始化

trainer = instantiate(cfg.trainer, _recursive_=False)
  • 传入的参数:
    在这里插入图片描述

    • data:training.dataset.sam2_datasets.TorchTrainMixedDataset
      在这里插入图片描述
    • model: training.model.sam2.SAM2Train
    • checkpoint:…
    • mode: train_only
    • optim: torch.optim.AdamW
    • loss: training.loss_fns.MultiStepMultiMasksAndIous
  • 配置信息赋值给实例变量
    在这里插入图片描述

  • 初始化其他配置信息…

  • self._setup_dataloaders() 设置数据加载器

    • 根据mode的格式实例化train_dataset,如下图所示(接7.1-1)的初始化模块):
      在这里插入图片描述

2)trainer.run

dataloader = self.train_dataset.get_loader(epoch=int(self.epoch))-----接7.1-2)的get_loader方法

7. sam2_dataset.py

7.1 MixedDataLoader 类​

1)初始化

  • 传入参数
    在这里插入图片描述
    这里的传入参数就是前面trainer初始化中,data的配置信息。
  • 初始化信息
    • 属性赋值
    • 设置数据集的周期(??没看懂这样的意义)
    • sam允许训练时使用多个数据集
      • 计算每个数据集的采样概率,概率为子数据集量/全部数据集量,若只有一个数据集,那么采样概率为1

2)get_loader

    def get_loader(self, epoch) -> Iterable:# 初始化数据加载器列表dataloaders = []# 遍历数据集和批次大小for d_idx, (dataset, batch_size) in enumerate(zip(self.datasets, self.batch_sizes)):# 处理每个数据集# 如果每个周期的阶段数 self.phases_per_epoch 大于 1,则处理数据集的分块和设置周期。if self.phases_per_epoch > 1:# Major epoch that looops over entire dataset# len(main_epoch) == phases_per_epoch * len(epoch)# 计算主周期和局部阶段main_epoch = epoch // self.phases_per_epoch# Phase with in the main epochlocal_phase = epoch % self.phases_per_epoch# Start of new data-epoch or job is resumed after preemtion.if local_phase == 0 or self.chunks[d_idx] is None:# set seed for dataset epoch# If using RepeatFactorWrapper, this step currectly re-samples indices before chunking.self._set_dataset_epoch(dataset, main_epoch)# Separate random generator for subset samplingg = torch.Generator()g.manual_seed(main_epoch)self.chunks[d_idx] = torch.chunk(torch.randperm(len(dataset), generator=g),self.phases_per_epoch,)dataset = Subset(dataset, self.chunks[d_idx][local_phase])# 如果是新的数据周期或工作在中断后恢复,则设置数据集的周期,并为数据集创建随机分块。else:self._set_dataset_epoch(dataset, epoch)# 创建DistributedSampler采样器,用于在分布式环境中对数据集进行采样sampler = DistributedSampler(dataset, shuffle=self.shuffle)sampler.set_epoch(epoch)# 创建 BatchSampler 对象,用于从采样器中按批次大小采样数据。batch_sampler = BatchSampler(sampler, batch_size, drop_last=self.drop_last)# 创建数据加载器dataloaders.append(DataLoader(dataset,num_workers=self.num_workers,pin_memory=self.pin_memory,batch_sampler=batch_sampler,collate_fn=self.collate_fn,worker_init_fn=self.worker_init_fn,))# 返回混合数据加载器return MixedDataLoader(dataloaders, self.dataset_prob)
http://www.xdnf.cn/news/239599.html

相关文章:

  • 中国县级2m精度耕地分布数据(2020年)
  • 深度学习概述
  • Silo 科学数据工具库安装与使用指南
  • 【closerAI ComfyUI】开源社区炸锅!comfyUI原生支持Step1X-Edit 图像编辑!离简单免费高效又进一步
  • 关键词排名工具查到的位置和真实搜索差距大是什么原因?
  • SpringBoot优雅关机
  • MicroPython 开发ESP32应用教程 之 ADC及应用实例:电池电量检测并显示
  • HarmonyOS NEXT应用开发-Notification Kit(用户通知服务)notificationManager.cancelAll
  • ComfyUI
  • 国标GB28181平台EasyGBS未来研发方向在哪?
  • 数字中国开新篇,数智化为何需要新引擎
  • SLAM中的状态估计理论:从基础到前沿的完整解析
  • C++初阶:类和对象(二)
  • 机器学习|通过线性回归了解算法流程
  • spring 面试题
  • 智能 + 安全:婴幼儿托育管理实训基地标准化建设方案
  • 【LLM】MOE混合专家大模型综述(重要模块原理)
  • AI中常用概念的理解
  • w313安康学院新型冠状病毒肺炎疫情防控专题网站设计与实现
  • 【python实用小脚本-43】用Python自动发送生日祝福,让情感更高效
  • 架构进阶:72页集管IT基础设施蓝图设计方案【附全文阅读】
  • Nautilus侧栏没有桌面
  • 通过Yoast设置SEO标题不生效
  • OpenCV学习笔记(完)
  • Linux -- 操作系统
  • dubbo泛化调用时transient字段失效问题
  • 什么是基尔霍夫第一定律
  • 【python】-基础语法3
  • Semtech公司简介以及主流产品
  • C++继承(下)