当前位置: 首页 > web >正文

Python Day38

Task:
1.Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
2.Dataloader类
3.minist手写数据集的了解


1. Dataset 类的 __getitem____len__ 方法

在 PyTorch (或类似深度学习框架) 中,Dataset 是一个抽象基类,用于表示你的数据。它通常用于将原始数据(例如图像文件、文本文件、CSV 数据等)处理成模型可以直接消费的格式。

Dataset 类有两个核心的特殊方法,它们是 Python 的“魔法方法”:

  • __len__(self):

    • 作用: 这个方法必须返回数据集中样本的总数量。
    • 实现: 当你创建一个 Dataset 的子类时,你需要实现它来告诉 PyTorch 这个数据集有多大。
    • 用处: Dataloader 需要知道总长度才能正确地进行批处理、洗牌和分发数据。
    • 示例:
      class MyDataset(Dataset):def __init__(self, data_list):self.data = data_list # 假设data_list是你的数据源def __len__(self):return len(self.data) # 返回数据源的长度def __getitem__(self, idx):# ... 具体实现将在下面说明pass
      
  • __getitem__(self, idx):

    • 作用: 这个方法用于根据给定的索引 idx 返回数据集中的一个样本。
    • 实现: 这是最关键的部分。你需要在其中定义如何加载、预处理(如图像变换、文本编码)并返回一个样本及其对应的标签。
    • 返回类型: 通常,它返回一个元组或字典,其中包含一个数据样本和其对应的标签。例如 (image_tensor, label_tensor)
    • 用处: 当 Dataloader 需要获取一个批次的数据时,它会内部多次调用 __getitem__ 来收集单个样本。
    • 示例:
      import torch
      from torch.utils.data import Datasetclass CustomImageDataset(Dataset):def __init__(self, image_paths, labels, transform=None):self.image_paths = image_pathsself.labels = labelsself.transform = transform # 用于图像预处理的转换def __len__(self):return len(self.image_paths)def __getitem__(self, idx):img_path = self.image_paths[idx]label = self.labels[idx]# 假设这里是加载图像的逻辑 (实际会用Pillow等库)# 为了演示,我们创建一个虚拟图像tensorimage = torch.randn(3, 224, 224) # 3 channels, 224x224 pixelsif self.transform:image = self.transform(image) # 应用预处理return image, label # 返回图像张量和标签
      

总结 Dataset 的作用和特殊方法:

Dataset 类负责:

  1. 数据抽象: 将原始数据封装成一个可迭代、可索引的对象。
  2. 数据加载: 在 __getitem__ 中处理从文件系统或内存中加载单个数据项的逻辑。
  3. 数据预处理: 在 __getitem__ 中应用必要的预处理步骤(如归一化、裁剪、数据增强)。
  4. 提供索引: __len____getitem__ 使得数据集可以通过索引访问,并知道其总大小。

2. DataLoader

DataLoader 是 PyTorch 中一个非常强大的工具,它建立在 Dataset 之上,负责高效地加载和批处理数据。它的核心功能是:

  • 批处理 (Batching): 将单个样本组合成批次,这是深度学习训练的常用方式,因为它可以提高计算效率,并有助于梯度下降的稳定。
  • 洗牌 (Shuffling): 在每个 epoch 开始时随机打乱数据,以防止模型学习到数据中的顺序模式,并提高模型的泛化能力。
  • 多进程数据加载 (Multiprocessing Data Loading): 可以使用多个工作进程并行加载数据,从而减少数据加载成为训练瓶颈的可能性。
  • 内存固定 (Pin Memory): 可以将张量加载到 CUDA 固定内存中,这可以加快数据传输到 GPU 的速度。

DataLoader 的主要参数:

  • dataset: 必须是 torch.utils.data.Dataset 的实例。这是 DataLoader 从中获取数据的来源。
  • batch_size: 每个批次包含的样本数量。
  • shuffle: 布尔值,如果设置为 True,则在每个 epoch 开始时打乱数据。
  • num_workers: 用于数据加载的子进程数量。设置为 0 意味着数据将在主进程中加载。大于 0 会开启多进程,通常能加快加载速度,但也需要更多内存。
  • drop_last: 布尔值,如果设置为 True,则如果数据集大小不能被 batch_size 整除,则最后一个不完整的批次将被丢弃。
  • collate_fn: 可选参数,一个函数,用于如何将单个样本列表合并成一个批次。默认情况下,它会尝试堆叠张量。如果你有复杂的数据结构(如变长序列),你可能需要自定义这个函数。

DataLoader 的使用:

DataLoader 是一个可迭代对象。你可以直接在 for 循环中使用它来获取批次数据。

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np# 假设我们有一个简单的Dataset
class SimpleDataset(Dataset):def __init__(self, num_samples=100):self.data = torch.randn(num_samples, 10) # 100个样本,每个样本10个特征self.labels = torch.randint(0, 2, (num_samples,)) # 100个标签,0或1def __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx], self.labels[idx]# 创建数据集实例
my_dataset = SimpleDataset(num_samples=100)# 创建DataLoader实例
train_loader = DataLoader(dataset=my_dataset,batch_size=16,shuffle=True,num_workers=0) # 简单示例,不使用多进程# 迭代DataLoader获取批次数据
for epoch in range(5): # 假设训练5个epochprint(f"\nEpoch {epoch+1}")for batch_idx, (data, labels) in enumerate(train_loader):print(f"  Batch {batch_idx+1}: data shape = {data.shape}, labels shape = {labels.shape}")# 在这里执行模型的前向传播、计算损失、反向传播等训练步骤if batch_idx >= 2: # 只打印前3个批次,避免输出过多break

DataLoaderDataset 的协作:

  • DataLoader 接收一个 Dataset 对象。
  • DataLoader 需要一个批次数据时,它会:
    1. 如果 shuffle=True,它会首先打乱 Dataset 的索引。
    2. 它会选择 batch_size 个索引。
    3. 对于每个选定的索引,它会调用 Dataset__getitem__(idx) 方法来获取单个样本。
    4. 它将这些单个样本集合起来(默认通过 torch.stacktorch.cat),形成一个批次张量。
    5. 最终将批次张量返回给你的训练循环。

3. MNIST 手写数字数据集的了解

MNIST (Modified National Institute of Standards and Technology) 是一个经典的、广泛使用的计算机视觉数据集,被誉为“深度学习的 Hello World”。

主要特点:

  1. 内容: 包含大量手写数字的灰度图像。
  2. 类别: 10 个类别,对应数字 0 到 9。
  3. 图像大小: 每张图像都是 28x28 像素。
  4. 数据量:
    • 训练集: 60,000 张图像,用于训练模型。
    • 测试集: 10,000 张图像,用于评估模型的性能。
  5. 图像格式: 灰度图像,每个像素的值通常在 0 到 255 之间,表示像素亮度。

MNIST 的重要性:

  • 入门级: 简单且足够小,适合初学者学习深度学习的基本概念和 PyTorch 的使用。
  • 基准: 由于其标准化和广泛使用,它经常作为新算法和模型架构的初步测试基准。
  • 低计算需求: 训练一个在 MNIST 上表现良好的模型通常不需要强大的 GPU,普通 CPU 也能完成。

PyTorch 中使用 MNIST:

PyTorch 的 torchvision 库提供了方便的工具来下载和加载 MNIST 数据集。

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 1. 定义数据转换 (Transformations)
# MNIST图像是PIL.Image类型,需要转换为Tensor,并进行归一化。
# 归一化是常用的预处理步骤,将像素值缩放到一个特定范围(例如0到1,或-1到1)。
# 对于MNIST,通常是 (mean=0.1307, std=0.3081),这是根据整个MNIST数据集计算得出的。
transform = transforms.Compose([transforms.ToTensor(), # 将PIL Image或numpy.ndarray转换为FloatTensor,并除以255将像素值缩放到0-1transforms.Normalize((0.1307,), (0.3081,)) # 归一化,(mean,) (std,),对于灰度图像是单通道
])# 2. 下载并加载训练数据集
# root: 数据存放的根目录
# train=True: 获取训练集
# download=True: 如果数据不存在,则下载
# transform: 应用上述定义的转换
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)# 3. 下载并加载测试数据集
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 4. 创建 DataLoader
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4) # num_workers可以根据你的CPU核心数调整
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4) # 测试集通常不打乱# 5. 遍历训练数据 (示例)
print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")for batch_idx, (data, target) in enumerate(train_loader):print(f"训练批次 {batch_idx+1}: data shape = {data.shape}, target shape = {target.shape}")# data.shape 会是 [batch_size, 1, 28, 28] (1是通道数,28x28是图像尺寸)# target.shape 会是 [batch_size]break # 只打印第一个批次# 6. 遍历测试数据 (示例)
for batch_idx, (data, target) in enumerate(test_loader):print(f"测试批次 {batch_idx+1}: data shape = {data.shape}, target shape = {target.shape}")break
http://www.xdnf.cn/news/10097.html

相关文章:

  • 特伦斯 S75 电钢琴:重塑音乐感知,臻享艺术之境
  • ADUM3201ARZ-RL7在混合动力电池监控中的25kV/μs CMTI与系统级ESD防护设计
  • Tornado WebSocket实时聊天实例
  • 58-dify案例分享-用 Dify 工作流 搭建数学错题本,考试错题秒变提分神器-同类型题生成篇
  • PHP学习笔记(十一)
  • 顶会新热门:机器学习可解释性
  • VScode-使用技巧-持续更新
  • 鸿蒙OSUniApp智能商品展示实战:打造高性能的动态排序系统#三方框架 #Uniapp
  • Kotlin JVM 注解详解
  • MySQL之数据库的内嵌函数和联合查询
  • Dify理论+部署+实战
  • 利用计算机模拟和玉米壳废料开发新型抗病毒药物合成方法
  • 详解Seata的核心组件TC、TM、RM
  • YOLOv8分割onnx实战及tensorRT部署
  • 黑森林实验室 FLUX.1Kontext:革新图像修改的 AI 力量
  • React 事件处理与合成事件机制揭秘
  • 计算机视觉入门:OpenCV与YOLO目标检测
  • 优化版本,增加3D 视觉 查看前面的记录
  • MySQL 的 super_read_only 和 read_only 参数
  • 板凳-------Mysql cookbook学习 (九)
  • MQTT的Thingsboards的使用
  • WebFuture:设置不自动删除操作日志
  • Celery简介
  • 全面解析:npm 命令、package.json 结构与 Vite 详解
  • 基于LBS的上门代厨APP开发全流程解析
  • 鸿蒙OSUniApp复杂表单与动态验证实践:打造高效的移动端表单解决方案#三方框架 #Uniapp
  • 特伦斯 S75 电钢琴:奏响极致音乐体验的华丽乐章
  • 大话软工笔记—分离之业务与管理
  • Spring Advisor增强规则实现原理介绍
  • 测试工程师学LangChain之promptTemplate 实战笔记