当前位置: 首页 > news >正文

深度学习3.5图像分类数据集

%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

代码执行流程图

下载FashionMNIST数据集
定义标签转换函数
构建数据加载器
可视化第一批次图像
配置批量加载参数
测试数据加载速度
动态调整图像尺寸
验证调整后的数据形状

3.5.1 读取数据集

trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)

下载并加载FashionMNIST数据集
‌关键参数‌:
transform=trans:将图像转换为张量(形状 [1, 28, 28],值域 [0,1])。
download=True:若本地无数据则自动下载。
数据集结构‌:
训练集:60,000 张 28x28 灰度图像。
测试集:10,000 张 28x28 灰度图像。

def get_fashion_mnist_labels(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]

标签映射
将数字标签(0-9)转换为可读的文本标签(如 0 → ‘t-shirt’)。

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):figsize = (num_cols * scale, num_rows * scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i, (ax, img) in enumerate(zip(axes, imgs)):if torch.is_tensor(img):ax.imshow(img.numpy())else:ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes

输入 imgs 可以是张量或PIL图像。
squeeze():移除单通道维度(1x28x28 → 28x28),否则 imshow 可能报错。
cmap=‘gray’:确保灰度图正确显示(默认可能为彩色)。

X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));

‌输出‌:显示 2行x9列 的图像网格,标题为对应的文本标签。
X.reshape(18, 28, 28):调整形状以匹配 imshow 的输入要求(原始形状为 18x1x28x28)。

在这里插入图片描述

3.5.2 读取小批量

batch_size = 256def get_dataloader_workers():return 4  # 根据CPU核心数调整(通常设为4-8)train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers())

shuffle=True:打乱训练数据顺序,避免模型记忆批次。
num_workers=4:启用4个进程并行加载数据,加速数据读取。

timer = d2l.Timer()
for X, y in train_iter:continue
print(f'加载时间:{timer.stop():.2f} sec')

‘2.30 sec’

3.5.3 整合所有组件

def load_data_fashion_mnist(batch_size, resize=None):trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize)) # Resize必须在ToTensor前trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))

‌功能扩展‌:支持调整图像尺寸(如 resize=64 将图像缩放为 64x64)。
‌预处理顺序‌:
Resize(若指定)
ToTensor(转为张量并归一化)

train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:print(f'X形状: {X.shape}, 数据类型: {X.dtype}')  # 输出如 torch.Size([32,1,64,64])print(f'y形状: {y.shape}, 数据类型: {y.dtype}')  # 输出如 torch.int64break

X形状: torch.Size([32, 1, 64, 64]), 数据类型: torch.float32
y形状: torch.Size([32]), 数据类型: torch.int64

X.shape = [batch_size, channels, height, width]
y 为标签张量,形状 [batch_size]

http://www.xdnf.cn/news/60643.html

相关文章:

  • elastic/go-elasticsearch与olivere/elastic
  • 乐家桌面安卓版2025下载-乐家桌面软件纯净版安装分享码大全
  • 【scikit-learn基础】--『监督学习』之 均值聚类
  • GPT,Genini, Claude Llama, DeepSeek,Qwen,Grok,选对LLM大模型真的可以事半功倍!
  • 发布事件和Insert数据库先后顺序
  • GeoJSON 格式详解与使用指南
  • Macbook IntelliJ IDEA终端无法运行mvn命令
  • 【2025面试Java常问八股之redis】zset数据结构的实现,跳表和B+树的对比
  • 1.Vue3 - 创建Vue3工程
  • JavaEE--2.多线程
  • RHCE 练习二:通过 ssh 实现两台主机免密登录以及 nginx 服务通过多 IP 区分多网站
  • 【基础算法】二分算法详解
  • 科大讯飞Q1营收46.6亿同比增长27.7%,扣非净利同比增长48.3%
  • [c语言日寄]免费文档生成器——Doxygen在c语言程序中的使用
  • uniapp-商城-31-shop页面中的 我的订单
  • 【大语言模型DeepSeek+ChatGPT+python】最新AI-Python机器学习与深度学习技术在植被参数反演中的核心技术应用
  • idea使用docker插件一键部署项目
  • Time to event :Kaplan-Meier曲线、Log Rank检验与Shiny R
  • Oracle EBS R12.2 安装 -- Step by Step
  • 利用Qt创建一个模拟问答系统
  • Oracle expdp的 EXCLUDE 参数详解
  • 【橘子大模型】Tools/Function call
  • 【MySQL】库的操作
  • MCU开发学习记录10 - 高级定时器学习与实践(HAL库)—PWM互补输出、死区控制、刹车控制 - STM32CubeMX
  • 邀请函 | 「软件定义汽车 同星定义软件」 TOSUN用户日2025·杭州站
  • SQL 中 ROLLUP 的使用方法
  • 系统安全及应用
  • Spark-SQL与Hive集成及数据分析实践
  • 【C++游戏引擎开发】第18篇:视锥体裁剪与光源剔除
  • XMLXXE 安全无回显方案OOB 盲注DTD 外部实体黑白盒挖掘