当前位置: 首页 > backend >正文

Day 38: Dataset类和DataLoader类

核心概念

在处理大规模数据集时,显存往往无法一次性存储所有数据,因此需要使用分批训练的方法。PyTorch提供了两个关键类来解决这个问题:

  1. DataLoader类:决定数据如何加载
  2. Dataset类:告诉程序去哪里找数据,如何读取单个样本,以及如何预处理

实战演练:MNIST数据集

1. 导入必要的库

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt# 设置随机种子,确保结果可复现
torch.manual_seed(42)

2. 数据预处理

# 数据预处理管道
transform = transforms.Compose([transforms.ToTensor(),  # 转换为张量并归一化到[0,1]transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的标准化参数
])

3. 加载MNIST数据集

# 加载训练集
train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform
)# 加载测试集
test_dataset = datasets.MNIST(root='./data',train=False,transform=transform
)

🔧 Dataset类详解

Dataset类的核心方法

PyTorch的torch.utils.data.Dataset是一个抽象基类,所有自定义数据集都需要继承它并实现两个核心方法:

  • __len__():返回数据集的样本总数
  • __getitem__(idx):根据索引idx返回对应样本的数据和标签

魔术方法示例

# __getitem__方法示例
class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __getitem__(self, idx):return self.data[idx]# 创建类的实例
my_list_obj = MyList()
# 可以使用索引访问元素,这会自动调用__getitem__方法
print(my_list_obj[2])  # 输出:30
# __len__方法示例
class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __len__(self):return len(self.data)# 使用len()函数获取元素数量,这会自动调用__len__方法
my_list_obj = MyList()
print(len(my_list_obj))  # 输出:5

查看单个样本

# 获取一个样本
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item()
image, label = train_dataset[sample_idx]
print(f"Label: {label}")# 可视化图像
def imshow(img):img = img * 0.3081 + 0.1307  # 反标准化npimg = img.numpy()plt.imshow(npimg[0], cmap='gray')plt.show()imshow(image)

DataLoader类详解

DataLoader负责将Dataset中的数据按批次加载,并提供多种数据加载策略:

# 创建训练数据加载器
train_loader = DataLoader(train_dataset,batch_size=64,    # 每个批次64张图片shuffle=True      # 随机打乱数据
)# 创建测试数据加载器
test_loader = DataLoader(test_dataset,batch_size=1000   # 每个批次1000张图片# shuffle=False   # 测试时不需要打乱数据
)

Dataset vs DataLoader 对比

维度DatasetDataLoader
核心职责定义"数据是什么"和"如何获取单个样本"定义"如何批量加载数据"和"加载策略"
核心方法__getitem____len__无自定义方法,通过参数控制
预处理位置__getitem__中通过transform执行无预处理逻辑
并行处理无(仅单样本处理)支持多进程加载
典型参数roottransformbatch_sizeshufflenum_workers

总结

Dataset类的职责

  • 数据内容定义:数据存储路径、读取方式
  • 预处理逻辑:图像变换、数据增强等
  • 返回格式:如(image_tensor, label)

DataLoader类的职责

  • 批量处理:控制batch_size
  • 数据打乱:shuffle参数
  • 并行加载:num_workers参数
  • 内存管理:防止一次性加载过多数据

实用技巧

  1. batch_size选择:通常选择2的幂次方(32、64、128等),这与GPU计算效率相关
  2. 数据预处理时机:在Dataset的__getitem__方法中进行,而不是DataLoader中
  3. 内存优化:DataLoader的num_workers参数可以开启多进程加载,提高效率

@浙大疏锦行

http://www.xdnf.cn/news/17587.html

相关文章:

  • 计算机网络摘星题库800题笔记 第5章 传输层
  • 达梦数据闪回查询-快速恢复表
  • 燕山大学计算机网络实验(2025最新)
  • SpringMVC的原理及执行流程?
  • uv 配置和简单使用
  • 飞算JavaAI全流程实操指南:从需求到部署的智能开发体验
  • 虚拟机高级玩法-网页也能运行虚拟机——WebAssembly
  • code-inspector-plugin插件
  • [ue5 shader] 路由申明和路由引用
  • 【SpringBoot】05 容器功能 - SpringBoot底层注解的应用与实战 - @Configuration + @Bean
  • 智能家居Agent:物联网设备的统一控制与管理
  • 无人机航拍数据集|第13期 无人机城市斑马线目标检测YOLO数据集963张yolov11/yolov8/yolov5可训练
  • 无人机智能返航模块技术分析
  • 无人机航拍数据集|第14期 无人机水体污染目标检测YOLO数据集3000张yolov11/yolov8/yolov5可训练
  • k8s-scheduler 解析
  • 让齿轮与斑马线共舞:汽车文化驿站及安全教育基地的展陈实践
  • 【工作笔记】win11系统docker desktop配置国内mirror不生效解决方案汇总整理
  • 7 种最佳 DBAN 替代方案,彻底擦除硬盘数据
  • 【实时Linux实战系列】实时环境监测系统架构设计
  • 思科、华为、华三如何切换三层端口?
  • 初识数据结构——优先级队列(堆!堆!堆!)
  • Java静态代理和动态代理
  • [SC]SystemC中的SC_FORK和SC_JOIN用法详细介绍
  • mysql登录失败 ERROR1698
  • Java多线程基础总结
  • Camera open failed
  • STM32学习笔记7-TIM输入捕获模式
  • MySQL-日志
  • JavaScript Const的基础使用
  • UE 手柄点击UI 事件