当前位置: 首页 > ops >正文

day38 python Dataset和Dataloader

目录

一、背景知识

二、数据预处理与数据集加载

三、Dataset类:定义“数据是什么”和“如何获取单个样本”

1. __getitem__方法详解

2. __len__方法详解

3. 自定义MNIST数据集类

4. 可视化原始图像

四、DataLoader类:定义“如何批量加载数据”和“加载策略”

五、总结

1. Dataset类的核心要点

2. DataLoader类的核心要点

3. 两者的协同工作


一、背景知识

MNIST数据集是一个非常经典的数据集,包含60000张训练图片和10000张测试图片,每张图片大小为28×28像素,共包含10个类别(0到9的数字)。由于每个数据的维度比较小,既可以视为结构化数据,用机器学习、MLP(多层感知机)训练,也可以视为图像数据,用卷积神经网络训练。

在处理大规模数据集时,显存常常无法一次性存储所有数据,因此需要使用分批训练的方法。PyTorch的DataLoader类可以自动将数据集切分为多个批次(batch),并支持多线程加载数据,从而提高数据加载效率。而Dataset类则用于定义数据集的读取方式和预处理方式。

二、数据预处理与数据集加载

在开始之前,我们需要对数据进行预处理。PyTorch的transforms模块提供了一系列常用的图像预处理操作。以下是我们的预处理流程:

transform = transforms.Compose([transforms.ToTensor(),  # 转换为张量并归一化到[0,1]transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差
])

接下来,我们加载MNIST数据集。如果没有下载过,datasets.MNIST会自动下载:

train_dataset = datasets.MNIST(root='./data',  # 数据存储路径train=True,  # 加载训练集download=True,  # 如果没有数据则自动下载transform=transform  # 应用预处理
)test_dataset = datasets.MNIST(root='./data',  # 数据存储路径train=False,  # 加载测试集transform=transform  # 应用预处理
)

这里需要注意的是,PyTorch的思路是在数据加载阶段就完成数据的预处理,这与我们通常的“先有数据集,后续再处理”的思路有所不同。

三、Dataset类:定义“数据是什么”和“如何获取单个样本”

torch.utils.data.Dataset是一个抽象基类,所有自定义数据集都需要继承它并实现两个核心方法:__len____getitem__

  • __len__方法:返回数据集的样本总数。

  • __getitem__方法:根据索引idx返回对应样本的数据和标签。

这两个方法是PyTorch对数据集的基本要求,只有实现了它们,数据集才能被DataLoader等工具兼容。这类似于一种接口约定,就像函数参数的规范一样。

在Python中,__getitem____len__是类的特殊方法(也叫魔术方法),它们不是像普通函数那样直接使用,而是需要在自定义类中进行定义,从而赋予类特定的行为。

1. __getitem__方法详解

__getitem__方法用于让对象支持索引操作。当使用[]语法访问对象元素时,Python会自动调用该方法。例如:

class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __getitem__(self, idx):return self.data[idx]my_list_obj = MyList()
print(my_list_obj[2])  # 输出:30

通过定义__getitem__方法,MyList类的实例能够像Python内置的列表一样使用索引获取元素。

2. __len__方法详解

__len__方法用于返回对象中元素的数量。当使用内置函数len()作用于对象时,Python会自动调用该方法。例如:

class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __len__(self):return len(self.data)my_list_obj = MyList()
print(len(my_list_obj))  # 输出:5

这里定义的__len__方法,使得MyList类的实例可以像普通列表一样被len()函数调用获取长度。

3. 自定义MNIST数据集类

为了更好地理解Dataset类的使用,我们来实现一个简化版本的MNIST数据集类:

class MNIST(Dataset):def __init__(self, root, train=True, transform=None):# 初始化:加载图片路径和标签self.data, self.targets = fetch_mnist_data(root, train)  # 假设 fetch_mnist_data 是一个函数self.transform = transform  # 预处理操作def __len__(self):return len(self.data)  # 返回样本总数def __getitem__(self, idx):# 获取指定索引的图像和标签img, target = self.data[idx], self.targets[idx]# 应用图像预处理if self.transform is not None:img = self.transform(img)return img, target  # 返回处理后的图像和标签

在这个类中,__getitem__方法负责根据索引获取单个样本,并应用预处理操作(如ToTensorNormalize)。这就好比厨师在准备单个菜品时,会进行切菜、调味等预处理操作。

4. 可视化原始图像

为了查看数据集中的图像,我们可以定义一个可视化函数imshow,并随机选择一张图片进行展示:

def imshow(img):img = img * 0.3081 + 0.1307  # 反标准化npimg = img.numpy()plt.imshow(npimg[0], cmap='gray')  # 显示灰度图像plt.show()sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item()  # 随机选择一张图片的索引
image, label = train_dataset[sample_idx]  # 获取图片和标签
print(f"Label: {label}")
imshow(image)

四、DataLoader类:定义“如何批量加载数据”和“加载策略”

DataLoader类的职责是将Dataset中的数据批量加载出来,并支持多线程加载,从而提高数据加载效率。它的使用非常简单:

train_loader = DataLoader(train_dataset,batch_size=64,  # 每个批次64张图片shuffle=True  # 随机打乱数据
)test_loader = DataLoader(test_dataset,batch_size=1000  # 每个批次1000张图片
)

DataLoader类的主要参数包括:

  • dataset:要加载的数据集。

  • batch_size:每个批次的样本数量。

  • shuffle:是否随机打乱数据。

  • num_workers:加载数据时使用的子进程数量,默认为0(不使用多进程)。

DataLoader类可以看作是“服务员”,它将Dataset类准备好的“菜品”(单个样本)按照订单(批量大小、是否打乱等策略)组合并上桌(批量加载)。

五、总结

通过以上内容的学习,我们可以对Dataset类和DataLoader类进行如下的总结:

维度DatasetDataLoader
核心职责定义“数据是什么”和“如何获取单个样本”定义“如何批量加载数据”和“加载策略”
核心方法__getitem__(获取单个样本)、__len__(样本总数)无自定义方法,通过参数控制加载逻辑
预处理位置__getitem__中通过transform执行预处理无预处理逻辑,依赖Dataset返回的预处理后数据
并行处理无(仅单样本处理)支持多进程加载(num_workers>0
典型参数root(数据路径)、transform(预处理)batch_sizeshufflenum_workers

1. Dataset类的核心要点

  • 定义数据的内容和格式:包括数据存储路径/来源、原始数据的读取方式、样本的预处理逻辑以及返回值格式。

  • 实现两个核心方法__len____getitem__,这是PyTorch对数据集的基本要求,也是与DataLoader兼容的关键。

2. DataLoader类的核心要点

  • 定义数据的加载方式和批量处理逻辑:通过batch_size控制每个批次的样本数量,通过shuffle决定是否随机打乱数据,通过num_workers设置多进程加载的子进程数量。

  • 依赖Dataset返回的预处理后数据DataLoader本身不负责预处理,而是直接使用Dataset返回的已经预处理好的数据。

3. 两者的协同工作

  • Dataset类是“厨师”,负责准备单个样本,包括数据的读取和预处理。

  • DataLoader类是“服务员”,负责将“厨师”准备好的单个样本按照订单(批量大小、是否打乱等策略)组合并上桌(批量加载)。

通过Dataset类和DataLoader类的协同工作,我们可以高效地处理和加载大规模数据集,为深度学习模型的训练提供有力支持。

@浙大疏锦行

http://www.xdnf.cn/news/9105.html

相关文章:

  • OpenCV CUDA模块图像处理------颜色空间处理之GPU 上交换图像的通道顺序函数swapChannels()
  • Wan2.1 图生视频模型内部协作流程
  • 02.【Qt开发】Qt Creator介绍及新建项目流程
  • Python打卡 DAY 38
  • 华为高斯数据库(GaussDB)深度解析:国产分布式数据库的旗舰之作
  • 局域协作中的前端调试:WebDebugX 在本地多端调试中的实践
  • CPU服务器的主要功能有哪些?
  • 高防CDN如何解决网站访问卡顿与崩溃问题?
  • VUE npm ERR! code ERESOLVE, npm ERR! ERESOLVE could not resolve, 错误有效解决
  • 鸿蒙仓颉开发语言实战教程:自定义组件
  • 将Windows11下的Ubuntu应用移动到其他盘
  • mysql中的MVCC
  • PH热榜 | 2025-05-24
  • DRF的使用
  • 【前端】【React】React性能优化系统总结
  • DAY07:Vue Router深度解析与多页面博客系统实战
  • 微信小程序的软件测试用例编写指南及示例
  • kafka SASL/PLAIN 认证及 ACL 权限控制
  • Mysql之用户管理
  • [25-cv-05718]BSF律所代理潮流品牌KAWS公仔(商标+版权)
  • 分布式项目保证消息幂等性的常见策略
  • 并发编程艺术--AQS底层源码解析(三)
  • 华为OD机试真题——构成正方形的数量(2025B卷:100分)Java/python/JavaScript/C++/C/GO六种最佳实现
  • P2340 [USACO03FALL] Cow Exhibition G
  • 时序模型上——ARIMA/
  • 云蝠 Voice Agent:开启大模型时代语音交互新纪元
  • AAOS系列之(四) ---APP端如何获取CarService中的各个服务代理
  • day8补充(中断驱动和队列缓冲实现高效数据处理)
  • day020-sed和find
  • 【C++高阶一】二叉搜索树