Python训练营打卡Day38
知识点回顾:
- Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
- Dataloader类
- minist手写数据集的了解
作业:了解下cifar数据集,尝试获取其中一张图片
CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。一共包含 10 个类别的 RGB 彩色图 片:飞机( airplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。图片的尺寸为 32×32 ,数据集中一共有 50000 张训练图片和 10000 张测试图片。 CIFAR-10 的图片样例如图所示。
下面这幅图就是列举了10各类,每一类展示了随机的10张图片:
(原文链接:https://blog.csdn.net/qq_40755283/article/details/125209463)
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np# 设置随机种子以确保结果可复现
torch.manual_seed(42)# 定义数据预处理
transform = transforms.Compose([transforms.ToTensor(), # 将图像转换为 Tensor
])# 加载 CIFAR-10 训练数据集
trainset = torchvision.datasets.CIFAR10(root='./data', # 数据存储路径train=True, # 是否为训练集download=True, # 是否下载数据transform=transform # 应用数据预处理
)# 创建数据加载器
trainloader = torch.utils.data.DataLoader(trainset, # 数据集batch_size=1, # 每次加载 1 张图片shuffle=True # 打乱数据顺序
)# 获取一个批次的数据(包含 1 张图片)
dataiter = iter(trainloader)
images, labels = next(dataiter)# CIFAR-10 数据集的类别名称
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 定义一个函数来显示图像
def imshow(img, title):img = img.numpy() # 将 Tensor 转换为 NumPy 数组img = np.transpose(img, (1, 2, 0)) # 调整维度顺序 [C, H, W] -> [H, W, C]plt.imshow(img)plt.title(title)plt.axis('off')plt.show()# 显示图像及其标签
imshow(images[0], f'Label: {classes[labels[0]]}')# 打印数据集信息
print(f"CIFAR-10 数据集包含 {len(trainset)} 张训练图片")
print(f"每张图片的尺寸为 {images[0].shape[1:]}")
print(f"数据集类别: {classes}")