当前位置: 首页 > news >正文

PyTorch数据集与数据集加载

PyTorch中的Dataset与DataLoader详解

1. Dataset基础

Dataset是PyTorch中表示数据集的抽象类,我们需要继承它并实现两个关键方法:

from torch.utils.data import Datasetclass CustomDataset(Dataset):def __init__(self, data, labels):"""初始化方法,加载数据"""self.data = dataself.labels = labelsdef __len__(self):"""返回数据集的大小"""return len(self.data)def __getitem__(self, idx):"""根据索引获取单个样本"""sample = self.data[idx]label = self.labels[idx]return sample, label

使用示例

# 假设我们有一些简单的数据
data = [[1, 2], [3, 4], [5, 6], [7, 8]]
labels = [0, 1, 0, 1]# 创建数据集实例
dataset = CustomDataset(data, labels)# 测试数据集
print(f"数据集大小: {len(dataset)}")  # 输出: 4
print(dataset[0])  # 输出: ([1, 2], 0)

2. DataLoader功能

DataLoader负责从Dataset中加载数据,并提供批处理、打乱顺序和多线程加载等功能。

from torch.utils.data import DataLoader# 创建DataLoader
dataloader = DataLoader(dataset,          # 数据集对象batch_size=2,     # 每批数据大小shuffle=True,     # 是否打乱数据num_workers=2     # 使用多少子进程加载数据
)# 遍历数据
for batch_idx, (batch_data, batch_labels) in enumerate(dataloader):print(f"批次 {batch_idx}:")print("数据:", batch_data)print("标签:", batch_labels)

3. 实际应用示例

图像数据集示例

import os
from PIL import Imageclass ImageDataset(Dataset):def __init__(self, img_dir, transform=None):self.img_dir = img_dirself.transform = transformself.img_names = os.listdir(img_dir)def __len__(self):return len(self.img_names)def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_names[idx])image = Image.open(img_path).convert('RGB')if self.transform:image = self.transform(image)# 假设文件名格式为 "label_image.jpg"label = int(self.img_names[idx].split('_')[0])return image, label

使用数据增强

from torchvision import transforms# 定义数据转换
transform = transforms.Compose([transforms.Resize(256),transforms.RandomCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 创建数据集
dataset = ImageDataset("path/to/images", transform=transform)# 创建DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

4. 高级功能

自定义批处理

from torch.utils.data.dataloader import default_collatedef custom_collate(batch):# 过滤掉None样本batch = [item for item in batch if item is not None]if len(batch) == 0:return Nonereturn default_collate(batch)dataloader = DataLoader(dataset, batch_size=4, collate_fn=custom_collate)

使用Subset划分数据集

from torch.utils.data import random_split# 假设我们有一个大的数据集
full_dataset = CustomDataset(data, labels)# 划分训练集和测试集
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])# 创建对应的DataLoader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

5. 性能优化技巧

  1. num_workers设置:根据CPU核心数设置合理的num_workers值(通常2-4)
  2. pin_memory:在GPU训练时设置pin_memory=True可以加速数据传输
  3. 预取数据:使用prefetch_factor参数(PyTorch 1.7+)
dataloader = DataLoader(dataset,batch_size=64,shuffle=True,num_workers=4,pin_memory=True,prefetch_factor=2
)

6. 常见问题解决

  1. 内存不足:减小batch_size或使用IterableDataset
  2. 数据加载慢:确保数据存储在SSD上,使用更快的文件格式(如HDF5)
  3. 数据不平衡:使用WeightedRandomSampler
from torch.utils.data import WeightedRandomSampler# 假设我们有不平衡的数据集
weights = [1.0 if label == 0 else 0.1 for _, label in dataset]
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)balanced_loader = DataLoader(dataset, batch_size=32, sampler=sampler)

通过合理使用Dataset和DataLoader,可以高效地管理和加载大规模数据集,为深度学习模型训练提供稳定、高效的数据管道。

http://www.xdnf.cn/news/274069.html

相关文章:

  • ICCV2023 | 视觉Transformer的Token-标签对齐
  • window-docker的容器使用宿主机音频设备
  • 深入探索 Java 区块链技术:从核心原理到企业级实践
  • nginx 核心功能 02
  • 【项目篇之统一硬盘操作】仿照RabbitMQ模拟实现消息队列
  • C++入门小馆:继承
  • 数据库-数据类型,表的约束和基本查询操作
  • SONiC-OTN代码详解(具体内容待续)
  • set autotrace报错
  • K8S的使用(部署pod\service)+安装kubesphere图形化界面使用和操作
  • 【机器学习案列-22】基于线性回归(LR)的手机发布价格预测
  • 【iOS】消息流程探索
  • 基于python的task--时间片轮询
  • 为了结合后端而学习前端的学习日志——【黑洞光标特效】
  • VMware-centOS7安装redis分布式集群
  • 《Java高级编程:从原理到实战 - 进阶知识篇五》
  • 统计学中的p值是什么?怎么使用?
  • Ray开源程序 是用于扩展 AI 和 Python 应用程序的统一框架。Ray 由一个核心分布式运行时和一组用于简化 ML 计算的 AI 库组成
  • 初识 iOS 开发中的证书固定
  • flink常用算子整理
  • QT | 常用控件
  • 个人文章不设置vip
  • MySQL复合查询全解析:从基础到多表关联与高级技巧
  • 【Hive入门】Hive与Spark SQL深度集成:Metastore与Catalog兼容性全景解析
  • 视频转GIF
  • 网狐系列三网通新钻石娱乐源码全评:结构拆解、三端实测与本地部署问题记录
  • ResNet改进(37):DenseBlock模块实现
  • 游戏引擎学习第257天:处理一些 Win32 相关的问题
  • 【Python】一直没搞懂迭代器是什么。。
  • 【Linux】SELinux 的基本操作与防火墙的管理