Dataset和Dataloader
知识点回顾:
- Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
- Dataloader类
- minist手写数据集的了解
作业:了解下cifar数据集,尝试获取其中一张图片
总结
维度 | Dataset | DataLoader |
---|---|---|
核心职责 | 定义“数据是什么”和“如何获取单个样本” | 定义“如何批量加载数据”和“加载策略” |
核心方法 | __getitem__ (获取单个样本)、__len__ (样本总数) | 无自定义方法,通过参数控制加载逻辑 |
预处理位置 | 在__getitem__ 中通过transform 执行预处理 | 无预处理逻辑,依赖Dataset 返回的预处理后数据 |
并行处理 | 无(仅单样本处理) | 支持多进程加载(num_workers>0 ) |
典型参数 | root (数据路径)、transform (预处理) | batch_size 、shuffle 、num_workers |
核心结论
-
Dataset
类:定义数据的内容和格式(即“如何获取单个样本”),包括:- 数据存储路径/来源(如文件路径、数据库查询)。
- 原始数据的读取方式(如图像解码为PIL对象、文本读取为字符串)。
- 样本的预处理逻辑(如裁剪、翻转、归一化等,通常通过
transform
参数实现)。 - 返回值格式(如
(image_tensor, label)
)。
-
DataLoader
类:定义数据的加载方式和批量处理逻辑(即“如何高效批量获取数据”),包括:- 批量大小(
batch_size
)。 - 是否打乱数据顺序(
shuffle
)。
- 批量大小(
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具
from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块
import matplotlib.pyplot as plt# 设置随机种子,确保结果可复现
torch.manual_seed(42)
# 1. 数据预处理,该写法非常类似于管道pipeline
# transforms 模块提供了一系列常用的图像预处理操作# 先归一化,再标准化
transform = transforms.Compose([transforms.ToTensor(), # 转换为张量并归一化到[0,1]transforms.Normalize((0.1307,), (0.3081,)) # MNIST数据集的均值和标准差,这个值很出名,所以直接使用
])
# 2. 加载MNIST数据集,如果没有会自动下载
train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False,transform=transform
)import matplotlib.pyplot as plt# 随机选择一张图片,可以重复运行,每次都会随机选择
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引
# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字
image, label = train_dataset[sample_idx] # 获取图片和标签
# 可视化原始图像(需要反归一化)
def imshow(img):img = img * 0.3081 + 0.1307 # 反标准化npimg = img.numpy()plt.imshow(npimg[0], cmap='gray') # 显示灰度图像plt.show()print(f"Label: {label}")
imshow(image)# 3. 创建数据加载器
train_loader = DataLoader(train_dataset,batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关shuffle=True # 随机打乱数据
)test_loader = DataLoader(test_dataset,batch_size=1000 # 每个批次1000张图片# shuffle=False # 测试时不需要打乱数据
)
看看cifar数据库,将代码中数据集名称换为CIFAR10或CIFAR100即可。
@浙大疏锦行