Python训练营---Day38
知识点回顾:
- Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
- Dataloader类
- minist手写数据集的了解
作业:了解下cifar数据集,尝试获取其中一张图片
CIFAR 数据集是机器学习领域中常用的图像分类基准数据集,由加拿大安大略省高级研究所(CIFAR)收集整理。它包含多个版本,最常用的是CIFAR-10和CIFAR-100。
1. CIFAR-10
- 数据规模:
- 训练集:50,000 张图片。
- 测试集:10,000 张图片。
- 类别数量:10 个类别,每个类别 6000 张图片。
- 类别名称:飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船、卡车
- 图片特点:
- 尺寸:32×32 像素的彩色图像(RGB 三通道)。
- 复杂度:包含自然场景中的常见物体,背景复杂。
2. CIFAR-100
- 数据规模:
- 训练集:50,000 张图片。
- 测试集:10,000 张图片。
- 类别数量:100 个类别,每个类别 500 张训练图片和 100 张测试图片。
- 类别结构:
- 100 个类别分为 20 个超类(Superclass),每个超类包含 5 个子类。
- 例如:超类 “鱼” 包含 “水族馆鱼”“比目鱼”“射线”“鲨鱼”“鳟鱼”。
- 图片特点:
- 尺寸和格式与 CIFAR-10 相同(32×32 RGB)。
- 类别更细粒度,分类难度更高。
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具
from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块
import matplotlib.pyplot as plt# 设置随机种子,确保结果可复现
torch.manual_seed(42)# 定义预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) #RGB 三个通道的均值均设为 0.5,标准差均设为 0.5
])# 加载CIFAR-10数据集,如果没有会自动下载
train_dataset = datasets.CIFAR10(root='./cifar_data',train=True,download=True,transform=transform
)test_dataset = datasets.CIFAR10(root='./cifar_data',train=False,transform=transform
)# 随机选择一张图片
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引
# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字
image, label = train_dataset[sample_idx] # 获取图片和标签
#为什么train_dataset[sample_idx]可以获取到图片和标签,
#是因为 datasets.MNIST这个类继承了torch.utils.data.Dataset类,这个类中有一个方法__getitem__,
#这个方法会返回一个tuple,tuple中第一个元素是图片,第二个元素是标签。# print(f"Image shape: {image.shape}, Label: {label}") # 可视化原始图像(需要反归一化)
def imshow_cifar(img, label=None, class_names=None):"""可视化CIFAR-10数据集的图片参数:- img: 标准化后的图像张量 [3, 32, 32]- label: 图像对应的标签(可选)- class_names: 类别名称列表(可选)"""# 默认CIFAR-10类别名称if class_names is None:class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']img = img * 0.5 + 0.5 # 反标准化# 将张量转换为NumPy数组并调整维度顺序 [C,H,W] → [H,W,C]npimg = img.numpy()npimg = np.transpose(npimg, (1, 2, 0))# 显示图像plt.imshow(npimg)# 添加标签if label is not None:plt.title(f"Label: {class_names[label]}")plt.axis('off') # 移除坐标轴plt.tight_layout()plt.show()imshow_cifar(image, label)