当前位置: 首页 > web >正文

Dataset和Dataloader

知识点回顾:

  1. Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
  2. Dataloader类
  3. minist手写数据集的了解

作业:了解下cifar数据集,尝试获取其中一张图片

总结

维度DatasetDataLoader
核心职责定义“数据是什么”和“如何获取单个样本”定义“如何批量加载数据”和“加载策略”
核心方法__getitem__(获取单个样本)、__len__(样本总数)无自定义方法,通过参数控制加载逻辑
预处理位置__getitem__中通过transform执行预处理无预处理逻辑,依赖Dataset返回的预处理后数据
并行处理无(仅单样本处理)支持多进程加载(num_workers>0
典型参数root(数据路径)、transform(预处理)batch_sizeshufflenum_workers

核心结论

  • Dataset:定义数据的内容和格式(即“如何获取单个样本”),包括:

    • 数据存储路径/来源(如文件路径、数据库查询)。
    • 原始数据的读取方式(如图像解码为PIL对象、文本读取为字符串)。
    • 样本的预处理逻辑(如裁剪、翻转、归一化等,通常通过transform参数实现)。
    • 返回值格式(如(image_tensor, label))。
  • DataLoader:定义数据的加载方式和批量处理逻辑(即“如何高效批量获取数据”),包括:

    • 批量大小(batch_size)。
    • 是否打乱数据顺序(shuffle)。

 

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具
from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块
import matplotlib.pyplot as plt# 设置随机种子,确保结果可复现
torch.manual_seed(42)
# 1. 数据预处理,该写法非常类似于管道pipeline
# transforms 模块提供了一系列常用的图像预处理操作# 先归一化,再标准化
transform = transforms.Compose([transforms.ToTensor(),  # 转换为张量并归一化到[0,1]transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差,这个值很出名,所以直接使用
])
# 2. 加载MNIST数据集,如果没有会自动下载
train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False,transform=transform
)import matplotlib.pyplot as plt# 随机选择一张图片,可以重复运行,每次都会随机选择
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引
# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字
image, label = train_dataset[sample_idx] # 获取图片和标签
# 可视化原始图像(需要反归一化)
def imshow(img):img = img * 0.3081 + 0.1307  # 反标准化npimg = img.numpy()plt.imshow(npimg[0], cmap='gray') # 显示灰度图像plt.show()print(f"Label: {label}")
imshow(image)# 3. 创建数据加载器
train_loader = DataLoader(train_dataset,batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关shuffle=True # 随机打乱数据
)test_loader = DataLoader(test_dataset,batch_size=1000 # 每个批次1000张图片# shuffle=False # 测试时不需要打乱数据
)

看看cifar数据库,将代码中数据集名称换为CIFAR10或CIFAR100即可。

 

@浙大疏锦行

http://www.xdnf.cn/news/9246.html

相关文章:

  • VR三维数字空间还原
  • 大模型(4)——Agent(基于大型语言模型的智能代理)
  • 计算机网络基础知识
  • 7000字基于 SpringBoot 的 Cosplay 文化展示与交流社区系统设计与实现
  • 批量文件重命名工具
  • Web安全测试-文件上传绕过-DVWA
  • 【机器学习基础】机器学习入门核心算法:K-近邻算法(K-Nearest Neighbors, KNN)
  • 高效多线程图像处理实战
  • pycharm 新UI 固定菜单栏 pycharm2025 中文版
  • 小样本分类新突破:QPT技术详解
  • Mac M1 安装 ffmpeg
  • winsock对话设计框架
  • 大咖课 | 后期-文本分析
  • 新编辑器编写指南--给自己的备忘
  • 【请关注】VC++ MFC常见异常问题及处理方法
  • 如何使用PHP创建一个安全的用户注册表单,包含输入验证、数据过滤和结果反馈教程。
  • 第三十三天打卡
  • Windows安装Docker部署dify,接入阿里云api-key进行rag测试
  • 新消息!阿里云ACP大模型认证有变化!
  • https下git拉取gitlab仓库源码
  • tmux 入门实用指南(面向远程 Linux 开发者)
  • 测试报告里都包含哪些内容?
  • 使用pnpm、vite搭建Phaserjs的开发环境
  • 常见的网络设备
  • 【iOS(swift)笔记-11】App版本升级时本地数据库sqlite更新逻辑
  • 二十九、面向对象底层逻辑-SpringMVC九大组件之MultipartResolver接口设计
  • leetcode每日一题 -- 2131.连接两字母单词得到的最长回文串
  • taro + vue3 实现小程序sse长连接实时对话
  • el-tree拖拽事件,限制同级拖拽,获取拖拽后节点的前后节点,同级拖拽合并父节点name且子节点加入目标节点里
  • 让 Deepseek 写一个尺码计算器