当前位置: 首页 > ds >正文

深度学习3.5 图像分类数据集

%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

代码执行流程图

下载FashionMNIST数据集
定义标签转换函数
构建数据加载器
可视化第一批次图像
配置批量加载参数
测试数据加载速度
动态调整图像尺寸
验证调整后的数据形状

3.5.1 读取数据集

trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)

下载并加载FashionMNIST数据集
‌关键参数‌:
transform=trans:将图像转换为张量(形状 [1, 28, 28],值域 [0,1])。
download=True:若本地无数据则自动下载。
数据集结构‌:
训练集:60,000 张 28x28 灰度图像。
测试集:10,000 张 28x28 灰度图像。

def get_fashion_mnist_labels(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]

标签映射
将数字标签(0-9)转换为可读的文本标签(如 0 → ‘t-shirt’)。

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):figsize = (num_cols * scale, num_rows * scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i, (ax, img) in enumerate(zip(axes, imgs)):if torch.is_tensor(img):ax.imshow(img.numpy())else:ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes

输入 imgs 可以是张量或PIL图像。
squeeze():移除单通道维度(1x28x28 → 28x28),否则 imshow 可能报错。
cmap=‘gray’:确保灰度图正确显示(默认可能为彩色)。

X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));

‌输出‌:显示 2行x9列 的图像网格,标题为对应的文本标签。
X.reshape(18, 28, 28):调整形状以匹配 imshow 的输入要求(原始形状为 18x1x28x28)。

在这里插入图片描述

3.5.2 读取小批量

batch_size = 256def get_dataloader_workers():return 4  # 根据CPU核心数调整(通常设为4-8)train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers())

shuffle=True:打乱训练数据顺序,避免模型记忆批次。
num_workers=4:启用4个进程并行加载数据,加速数据读取。

timer = d2l.Timer()
for X, y in train_iter:continue
print(f'加载时间:{timer.stop():.2f} sec')

‘2.30 sec’

3.5.3 整合所有组件

def load_data_fashion_mnist(batch_size, resize=None):trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize)) # Resize必须在ToTensor前trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))

‌功能扩展‌:支持调整图像尺寸(如 resize=64 将图像缩放为 64x64)。
‌预处理顺序‌:
Resize(若指定)
ToTensor(转为张量并归一化)

train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:print(f'X形状: {X.shape}, 数据类型: {X.dtype}')  # 输出如 torch.Size([32,1,64,64])print(f'y形状: {y.shape}, 数据类型: {y.dtype}')  # 输出如 torch.int64break

X形状: torch.Size([32, 1, 64, 64]), 数据类型: torch.float32
y形状: torch.Size([32]), 数据类型: torch.int64

X.shape = [batch_size, channels, height, width]
y 为标签张量,形状 [batch_size]

http://www.xdnf.cn/news/1035.html

相关文章:

  • YOLO11改进 | 特征融合Neck篇之Lowlevel Feature Alignment机制:多尺度检测的革新性突破
  • C 语言开发问题:使用 <assert.h> 时,定义的 #define NDEBUG 不生效
  • vin码识别技术-车辆vin识别代码-Java接口集成
  • 《理解C++宏:从#define到条件编译》
  • 【工具】VS Code/Cursor 编辑器状态栏颜色自定义指南
  • 装饰模式:动态扩展对象功能的优雅设计
  • AI Agent开发第35课-揭秘RAG系统的致命漏洞与防御策略
  • 极刻AI搜v1.0 问一次问题 AI工具一起答
  • 城市客运安全员证适用岗位及要求
  • QtCreator的设计器、预览功能能看到程序图标,编译运行后图标消失
  • 关于金碟云星空批号问题
  • 自动化测试
  • Psychology 101 期末测验(附答案)
  • Ubuntu 系统下安装和使用性能分析工具 perf
  • HarmonyOS-ArkUI: animateTo 显式动画
  • Git SSH 密钥多个 Git 来源
  • 承兑汇票文字录入解决方案-承兑汇票识别接口-C++集成方式
  • SQL优化
  • 安卓逆向工程:从APK到内核的层级技术解析
  • 聚客AI万字解密AI-Agent大模型智能体:从架构设计到工业落地的全栈指南
  • 算法题(130):激光炸弹
  • 力扣刷题Day 23:最长连续序列(128)
  • Azkaban集群搭建
  • 基于Python的图片/签名转CAD小工具开发方案
  • 13.电阻在EMC设计中的妙用
  • 黑苹果win10和macOS双系统
  • C++ 的史诗级进化:从C++98到C++20
  • MySQL 触发器
  • 三轴云台之激光测距技术篇
  • 软件工程师中级考试-上午知识点总结(上)