Day 38: Dataset类和DataLoader类
核心概念
在处理大规模数据集时,显存往往无法一次性存储所有数据,因此需要使用分批训练的方法。PyTorch提供了两个关键类来解决这个问题:
- DataLoader类:决定数据如何加载
- Dataset类:告诉程序去哪里找数据,如何读取单个样本,以及如何预处理
实战演练:MNIST数据集
1. 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt# 设置随机种子,确保结果可复现
torch.manual_seed(42)
2. 数据预处理
# 数据预处理管道
transform = transforms.Compose([transforms.ToTensor(), # 转换为张量并归一化到[0,1]transforms.Normalize((0.1307,), (0.3081,)) # MNIST数据集的标准化参数
])
3. 加载MNIST数据集
# 加载训练集
train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform
)# 加载测试集
test_dataset = datasets.MNIST(root='./data',train=False,transform=transform
)
🔧 Dataset类详解
Dataset类的核心方法
PyTorch的torch.utils.data.Dataset
是一个抽象基类,所有自定义数据集都需要继承它并实现两个核心方法:
__len__()
:返回数据集的样本总数__getitem__(idx)
:根据索引idx返回对应样本的数据和标签
魔术方法示例
# __getitem__方法示例
class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __getitem__(self, idx):return self.data[idx]# 创建类的实例
my_list_obj = MyList()
# 可以使用索引访问元素,这会自动调用__getitem__方法
print(my_list_obj[2]) # 输出:30
# __len__方法示例
class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __len__(self):return len(self.data)# 使用len()函数获取元素数量,这会自动调用__len__方法
my_list_obj = MyList()
print(len(my_list_obj)) # 输出:5
查看单个样本
# 获取一个样本
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item()
image, label = train_dataset[sample_idx]
print(f"Label: {label}")# 可视化图像
def imshow(img):img = img * 0.3081 + 0.1307 # 反标准化npimg = img.numpy()plt.imshow(npimg[0], cmap='gray')plt.show()imshow(image)
DataLoader类详解
DataLoader负责将Dataset中的数据按批次加载,并提供多种数据加载策略:
# 创建训练数据加载器
train_loader = DataLoader(train_dataset,batch_size=64, # 每个批次64张图片shuffle=True # 随机打乱数据
)# 创建测试数据加载器
test_loader = DataLoader(test_dataset,batch_size=1000 # 每个批次1000张图片# shuffle=False # 测试时不需要打乱数据
)
Dataset vs DataLoader 对比
维度 | Dataset | DataLoader |
---|---|---|
核心职责 | 定义"数据是什么"和"如何获取单个样本" | 定义"如何批量加载数据"和"加载策略" |
核心方法 | __getitem__ 、__len__ | 无自定义方法,通过参数控制 |
预处理位置 | 在__getitem__ 中通过transform 执行 | 无预处理逻辑 |
并行处理 | 无(仅单样本处理) | 支持多进程加载 |
典型参数 | root 、transform | batch_size 、shuffle 、num_workers |
总结
Dataset类的职责
- 数据内容定义:数据存储路径、读取方式
- 预处理逻辑:图像变换、数据增强等
- 返回格式:如
(image_tensor, label)
DataLoader类的职责
- 批量处理:控制batch_size
- 数据打乱:shuffle参数
- 并行加载:num_workers参数
- 内存管理:防止一次性加载过多数据
实用技巧
- batch_size选择:通常选择2的幂次方(32、64、128等),这与GPU计算效率相关
- 数据预处理时机:在Dataset的
__getitem__
方法中进行,而不是DataLoader中 - 内存优化:DataLoader的num_workers参数可以开启多进程加载,提高效率
@浙大疏锦行