当前位置: 首页 > java >正文

python学习打卡day38

DAY 38 Dataset和Dataloader类

对应5. 27作业

知识点回顾:

  1. Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
  2. Dataloader类
  3. minist手写数据集的了解

作业:了解下cifar数据集,尝试获取其中一张图片

MNIST手写数字数据集。该数据集包含60000张训练图片和10000张测试图片,每张图片大小为28*28像素,共包含10个类别。因为每个数据的维度比较小,所以既可以视为结构化数据,用机器学习、MLP训练,也可以视为图像数据,用卷积神经网络训练。

 导入Dataset类和Dataloader类必要的库

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具
from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块
import matplotlib.pyplot as plt# 设置随机种子,确保结果可复现
torch.manual_seed(42)

其中torchvision是一个计算机视觉的库,它的常见方法如下:

torchvision

  1. datasets       # 视觉数据集(如 MNIST、CIFAR)
  2.  transforms     # 视觉数据预处理(如裁剪、翻转、归一化)
  3. models         # 预训练模型(如 ResNet、YOLO)
  4. utils          # 视觉工具函数(如目标检测后处理)
  5. io             # 图像/视频 IO 操作

 1.其中的transforms模块提供了一系列常用的图像预处理操作:

# 先归一化,再标准化
transform = transforms.Compose([transforms.ToTensor(),  # 转换为张量并归一化到[0,1]transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差,这个值很出名,所以直接使用
])

2.MNIST数据集

# 2. 加载MNIST数据集,如果没有会自动下载
train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False,transform=transform
)

随机取出一张图片(包括图片和标签) 

import matplotlib.pyplot as plt# 随机选择一张图片,可以重复运行,每次都会随机选择
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引
# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字
image, label = train_dataset[sample_idx] # 获取图片和标签

可视化取出的图像 

# 可视化原始图像(需要反归一化)
def imshow(img):img = img * 0.3081 + 0.1307  # 反标准化npimg = img.numpy()plt.imshow(npimg[0], cmap='gray') # 显示灰度图像plt.show()print(f"Label: {label}")
imshow(image)

当然也可以用相同的思路取出两张:

import matplotlib.pyplot as plt
import torch# 随机选择两张图片的索引
sample_idx_1 = torch.randint(0, len(train_dataset), size=(1,)).item()
sample_idx_2 = torch.randint(0, len(train_dataset), size=(1,)).item()# 获取图片和标签
image_1, label_1 = train_dataset[sample_idx_1]
image_2, label_2 = train_dataset[sample_idx_2]# 定义一个函数来反归一化并显示图像
def imshow(img):img = img * 0.3081 + 0.1307  # 反标准化npimg = img.numpy()plt.imshow(npimg[0], cmap='gray')# 创建一个包含两个子图的画布
fig, axes = plt.subplots(1, 2, figsize=(10, 5))# 显示第一张图片
plt.sca(axes[0])
imshow(image_1)
axes[0].set_title(f'Label: {label_1}')
axes[0].axis('off')# 显示第二张图片
plt.sca(axes[1])
imshow(image_2)
axes[1].set_title(f'Label: {label_2}')
axes[1].axis('off')plt.show()

我们是如何通过dataset类取出图像的呢?? 

PyTorch 的torch.utils.data.Dataset是一个抽象基类,所有自定义数据集都需要继承它并实现两个核心方法:

- __len__():返回数据集的样本总数。

- __getitem__(idx):根据索引idx返回对应样本的数据和标签。

PyTorch 要求所有数据集必须实现__getitem__和__len__,这样才能被DataLoader等工具兼容。这是一种接口约定,类似函数参数的规范。这意味着,如果你要创建一个自定义数据集,你需要实现这两个方法,否则PyTorch将无法识别你的数据集。

__getitem__方法用于让对象支持索引操作,当使用[]语法访问对象元素时,Python 会自动调用该方法。

# 示例代码
class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __getitem__(self, idx):return self.data[idx]# 创建类的实例
my_list_obj = MyList()
# 此时可以使用索引访问元素,这会自动调用__getitem__方法
print(my_list_obj[2])  # 输出:30

__len__方法用于返回对象中元素的数量,当使用内置函数len()作用于对象时,Python 会自动调用该方法。

class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __len__(self):return len(self.data)# 创建类的实例
my_list_obj = MyList()
# 使用len()函数获取元素数量,这会自动调用__len__方法
print(len(my_list_obj))  # 输出:5

再介绍一下Dataloader类

# 3. 创建数据加载器
train_loader = DataLoader(train_dataset,batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关shuffle=True # 随机打乱数据
)test_loader = DataLoader(test_dataset,batch_size=1000 # 每个批次1000张图片# shuffle=False # 测试时不需要打乱数据
)

作业: 

维度CIFAR 数据集MNIST 手写数据集
创建机构 / 背景由加拿大先进研究所(CIFAR)开发,用于计算机视觉研究由纽约大学柯朗数学科学研究所开发,用于手写数字识别研究
数据类型自然物体彩色图像(如动物、交通工具等)手写数字灰度图像(0-9)
图像分辨率32×32 像素(RGB 三通道,彩色图像)28×28 像素(单通道,灰度图像)
数据集规模- 总样本数:60,000 张
- 训练集:50,000 张
- 测试集:10,000 张
- 总样本数:70,000 张
- 训练集:60,000 张
- 测试集:10,000 张
类别数量- CIFAR-10:10 个大类
- CIFAR-100:100 个细分类别(20 个超类)
10 个类别(数字 0-9)
任务难度- 图像分辨率低但包含复杂背景和类内差异
- CIFAR-100 因类别多、区分度小,难度更高
图像背景简单,数字形态相对固定,难度较低
典型应用场景图像分类、目标识别、深度学习算法基准测试(如 CNN 优化)手写数字识别、基础算法验证(如神经网络入门案例)
数据预处理需进行色彩归一化、数据增强(如裁剪、翻转)等处理通常仅需灰度归一化和简单降噪处理
模型性能基准- CIFAR-10 顶尖模型准确率:~97%
- CIFAR-100 顶尖模型准确率:~87%
顶尖模型准确率:~99.7%(如 CNN)
相似点- 均为图像分类领域经典基准数据集
- 均包含训练集和测试集,结构标准化
- 广泛用于算法教学、研究和性能对比
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具
from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块
import matplotlib.pyplot as plt# 设置随机种子,确保结果可复现
torch.manual_seed(42)
# 定义数据预处理
transform = transforms.Compose([transforms.ToTensor(),  # 将图像转换为Tensortransforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化处理,将像素值从[0,1]缩放到[-1,1]
])

 

# 加载训练集
train_dataset = datasets.CIFAR10(root='./data',  # 数据存放路径train=True,  # 是否为训练集download=True,  # 如果数据不存在,是否自动下载transform=transform  # 数据预处理
)# 加载测试集
test_dataset = datasets.CIFAR10(root='./data',  # 数据存放路径train=False,  # 是否为测试集transform=transform  # 数据预处理
)
import matplotlib.pyplot as plt
# 类别名称
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# 随机选择一张图片,可以重复运行,每次都会随机选择
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引
# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字
image, label = train_dataset[sample_idx] # 获取图片和标签
# 可视化原始图像(需要反归一化)
def imshow(img, title=None):img = img / 2 + 0.5  # 反归一化:将[-1,1]范围转回[0,1]npimg = img.numpy()plt.figure(figsize=(4, 4))plt.imshow(np.transpose(npimg, (1, 2, 0)))  # 调整通道顺序:从[C,H,W]到[H,W,C]if title:plt.title(title)plt.axis('off')plt.show()print(f"Label: {label} ({classes[label]})")
imshow(image, f"Label: {classes[label]}")

@浙大疏锦行

http://www.xdnf.cn/news/9550.html

相关文章:

  • 截图后怎么快速粘贴到notability?
  • day22-定时任务故障案例
  • 秒杀系统—2.第一版初步实现的技术文档
  • 医院闭环系统业务介绍
  • Linux基础 -- 设备树引脚复用之`/omit-if-no-ref/` 用法解析
  • 8.7 基于EAP-AKA的订阅转移
  • Springboot 集成 TDengine3.0版本
  • git stash 的使用
  • qt ubuntu 20.04 交叉编译
  • python实战:在Linux服务器上使用LibreOffice命令行批量接受Word文档的所有修订
  • MCP 与 AI 模型的用户隐私保护——如何让人工智能更懂“界限感”?
  • Python-114:字符串字符类型排序问题
  • HBO Max 中国大陆订阅与使用终极指南(2025 最新)
  • LangChain4j(17)——MCP客户端
  • 在PHP编程中包(Package)和库(Library)怎么区分?
  • 企业级AI开启落地战,得场景者得天下
  • LeeCode 94. 二叉树的中序遍历
  • YARN架构解析:大数据资源管理核心
  • 【MYSQL】mysql单表亿级数据查询优化处理
  • 2021年认证杯SPSSPRO杯数学建模D题(第二阶段)停车的策略全过程文档及程序
  • 探寻黄金奶源带,悠纯乳业打造西北乳业新标杆
  • Spring AI框架快速入门
  • day12 leetcode-hot100-20(矩阵3)
  • 【Linux】网络(上)
  • Vue开发系列——如何使用Vue
  • 图像卷积OpenCV C/C++ 核心操作
  • 【DB2】ERRORCODE=-4499, SQLSTATE=08001
  • 【C++基础知识】匿名命名空间
  • mysql prepare statement
  • 如何查询服务器的端口号