当前位置: 首页 > backend >正文

PyTorch入门-torchvision

torchvision

torchvision 是 PyTorch 的一个重要扩展库,专门针对计算机视觉任务设计。它提供了丰富的预训练模型、常用数据集、图像变换工具和计算机视觉组件,大大简化了视觉相关深度学习项目的开发流程。

我们可以在Pytorch的官网找到torchvision的文档

在这里插入图片描述

文档中提供了很多数据集

在这里插入图片描述

这里以CIFAR10为例,它是图像分类常用的数据集

CIFAR-10 数据集由 60,000 张 32x32 像素的彩色图像组成,分为 10 个类别,每个类别有 6,000 张图像。其中 50,000 张是训练图像,10,000 张是测试图像。

数据集分为五个训练批次和一个测试批次,每个批次包含 10,000 张图像。测试批次包含每个类别中随机选择的 1,000 张图像。训练批次包含剩余的图像,顺序随机,但某些训练批次可能包含一个类别的更多图像。所有训练批次加起来正好包含每个类别的 5,000 张图像。

在这里插入图片描述
在这里插入图片描述

除了数据集之外,还提供了模型torchvision.models 模块包含了一系列预训练的深度学习模型,广泛应用于图像分类、目标检测、语义分割等任务。

我们可以通过代码下载数据集

import torchvisiontrans_set = torchvision.datasets.CIFAR10(root = "./dataset",train= True,download= True)
test_set = torchvision.datasets.CIFAR10(root = "./dataset",train= False,download= True)

参数列表

  1. root (str):
    • 数据集存储的路径,数据将下载到此目录下。
  2. train (bool, optional):
    • 如果为 True,则加载训练集;如果为 False,则加载测试集。默认值为 True
  3. transform (callable, optional):
    • 一个函数/转换,用于对图像进行预处理,比如数据增强、归一化等。
  4. target_transform (callable, optional):
    • 一个函数/转换,用于对目标(标签)进行处理。
  5. download (bool, optional):
    • 如果为 True,则从网上下载数据集(如果在指定路径中不存在)。默认值为 False

下载完成后可以看到项目目录中的数据集
在这里插入图片描述

我们可以打印一下print("训练集数量:", len(trans_set)) 查看训练集数量

在这里插入图片描述

完整代码如下,可以看到我们的第一个图片是cat

import torchvision# 下载并加载CIFAR10训练数据集
trans_set = torchvision.datasets.CIFAR10(root = "./dataset", train= True, download= True)# 下载并加载CIFAR10测试数据集
test_set = torchvision.datasets.CIFAR10(root = "./dataset", train= False, download= True)# 获取测试集的第一个样本和对应的标签
img, target = test_set[0]
# 显示测试集中的类别标签
print(test_set.classes) # ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# 显示样本的图像数据
print(img) # <PIL.Image.Image image mode=RGB size=32x32 at 0x1BF002C7710>
# 显示样本的标签
print(target) # 3
# 根据标签索引对应的类别名称
print(test_set.classes[target]) # cat
# 显示图像
# 在这里使用PIL库的Image模块的show方法,直接在屏幕上展示图像
img.show()

这个数据集的图片都比较小(32x32 像素),放大以后虽然这个看起来并不像猫,反而像老鼠,但是它就是cat

在这里插入图片描述

上面我们得到的数据类型是PIL,我们需要转为tensor类型,我们只需要新增一个Compose然后修改dataset代码

# 定义数据集转换
dataset_transform = torchvision.transforms.Compose([# 将图像数据转换为 Tensortorchvision.transforms.ToTensor()    # 还可以对 Tensor 进行归一化,参数分别表示均值和标准差#torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 下载并加载CIFAR10训练数据集
# 参数:
#   root: 指定数据集的保存路径
#   train: 指示是训练数据集(True)还是测试数据集(False)
#   transform: 对数据集中的每个图像应用的转换操作
#   download: 如果数据集不存在于指定路径且设置为True,则会自动下载数据集
trans_set = torchvision.datasets.CIFAR10(root = "./dataset", train= True, transform= dataset_transform,download= True)
# 下载并加载CIFAR10测试数据集,参数同上
test_set = torchvision.datasets.CIFAR10(root = "./dataset", train= False,transform= dataset_transform, download= True)

然后我们执行之后,控制台会打印图片,此时是我们想要的tensor数据类型(tensor类型图片不能使用show()

在这里插入图片描述

我们就可以显示在tensorBoard中

writer = SummaryWriter("pics")
# 获取测试集的10个样本和对应的标签
for i in range(10):img, target = test_set[i]writer.add_image("test_set", img, i)writer.close()

仔细看,能够依稀辨认出第十张图片是车
在这里插入图片描述

在这里插入图片描述

http://www.xdnf.cn/news/9266.html

相关文章:

  • 零基础远程连接课题组Linux服务器,安装anaconda,配置python环境(换源),在服务器上运行python代码【3/3 适合小白,步骤详细!!!】
  • 【R语言编程绘图-折线图】
  • Redis C语言连接教程
  • Linux 环境下C、C++、Go语言编译环境搭建秘籍
  • 常见编码小结
  • 常见JDK安装配置
  • springboot 笔记
  • Redis核心数据结构操作指南:字符串、哈希、列表详解
  • 【K8S】K8S基础概念
  • Java spingboot项目 在docker运行,需要含GDAL的JDK
  • 飞牛fnNAS手机相册备份及AI搜图
  • 博图SCL基础知识-表达式及赋值运算
  • 甲醇 燃料 不也有碳排放吗?【AI回答版】
  • 得物Java开发面试题及参考答案(下)
  • Linux操作系统概述
  • 【Canvas与日月星辰】烈日当空
  • 关于git的使用
  • 【漏洞与预防】Microsoft Windows 文件资源管理器欺骗漏洞预防
  • 【免费】【无需登录/关注】Base64 图片转换工具网页
  • 【Java】DelayQueue
  • LangGraph(七)——Workflows
  • 基于物联网(IoT)的电动汽车(EVs)智能诊断
  • Java组合、聚合与关联:核心区别解析
  • AWS WebRTC:获取信令服务节点和ICE服务节点
  • 深度解读 Qwen3 大语言模型的关键技术
  • 【Elasticsearch】ingest对于update操作起作用吗?
  • Android15 Camera Hal设置logLevel控制日志输出
  • vue2使用el-tree实现两棵树间节点的拖拽复制
  • LeetCode 2894.分类求和并作差:数学O(1)一行解决
  • Java提取markdown中的表格