当前位置: 首页 > news >正文

pytorch学习1(DataSet+Transforms+TensorBoard)

        在深度学习项目中,高效的数据处理和模型监控是成功的关键。PyTorch 提供了强大的工具集来简化这些任务。本文将探讨 PyTorch 中的三个核心组件:Dataset(数据集处理)、TensorBoard(可视化监控)和 Transforms(数据预处理),并通过实际代码示例展示如何将它们结合起来使用。

  Dataset

Dataset 是 PyTorch 中用于表示数据集的抽象类,它提供了标准化的方式来组织和访问数据。通过实现自定义 Dataset 类,我们可以轻松地将任何格式的数据集成到 PyTorch 训练流程中。

# 导入必要的库
from torch.utils.data import Dataset  # PyTorch的数据集基类
from PIL import Image  # Python图像处理库
import os  # 操作系统接口库

# 自定义数据集类,继承PyTorch的Dataset基类
class MyData(Dataset):
"""
自定义数据集类,用于加载图像数据和对应标签
参数:
root_dir: 数据集根目录
label_dir: 标签子目录名称(也作为类别标签)
"""

    def __init__(self, root_dir, label_dir):
"""
初始化函数
Args:
root_dir: 数据集根目录路径
label_dir: 标签/类别目录名称
"""
self.root_dir = root_dir  # 存储根目录路径
self.label_dir = label_dir  # 存储标签目录名称(也作为类别标签)
self.path = os.path.join(self.root_dir, self.label_dir)  # 拼接完整路径
self.img_path = os.listdir(self.path)  # 获取该类别下所有图像文件名列表

    def __getitem__(self, idx):
"""
获取单个样本
Args:
idx: 样本索引
Returns:
img: PIL Image对象
label: 字符串形式的类别标签
"""
img_name = self.img_path[idx]  # 获取指定索引的图像文件名
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)  # 拼接完整图像路径
img = Image.open(img_item_path)  # 使用PIL加载图像
label = self.label_dir  # 使用目录名作为标签
return img, label  # 返回图像和标签

    def __len__(self):
"""
返回数据集大小
Returns:
该类别下的图像数量
"""
return len(self.img_path)  # 返回图像列表长度


# 数据集路径设置
root_dir = "train"  # 训练集根目录
ants_label_dir = "ants_label"  # 蚂蚁类别目录名
bees_label_dir = "bees_label"  # 蜜蜂类别目录名

# 创建数据集实例
ants_dataset = MyData(root_dir, ants_label_dir)  # 蚂蚁数据集
bees_dataset = MyData(root_dir, bees_label_dir)  # 蜜蜂数据集


Transforms

        ,

        原始图像数据通常不能直接输入神经网络,Transforms作为数据预处理的强大工具,可以帮助我们进行一系列预处理:

  • 尺寸标准化

  • 数值归一化

  • 数据增强

# 内置模块

from torchvision import transforms

# 基础转换
basic_transform = transforms.Compose([
transforms.Resize(256),          # 调整大小
transforms.CenterCrop(224),      # 中心裁剪
transforms.ToTensor(),           # 转为张量
transforms.Normalize(            # 标准化
mean=[0.485, 0.456, 0.406], 
std=[0.229, 0.224, 0.225]
)
])

# 数据增强
augmentation_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),  # 随机水平翻转
transforms.RandomRotation(15),      # 随机旋转
transforms.ColorJitter(             # 颜色抖动
brightness=0.2,
contrast=0.2,
saturation=0.2,
hue=0.1
),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])


TensorBoard

它是一个训练可视化工具,下面举一个非常简单的小例子:

from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter("logs")
for i in range(100):writer.add_scalar("y=x", i, i)
writer.close()    


总结

        通过合理使用Dataset、Transforms和TensorBoard这三个PyTorch核心组件,我们可以构建高效、可维护的深度学习管道。Dataset提供了灵活的数据加载方式,Transforms确保了数据的一致性和多样性,而TensorBoard则让训练过程变得透明可控。掌握这三者的使用是成为PyTorch高手的重要一步。

http://www.xdnf.cn/news/1123381.html

相关文章:

  • LeetCode 692题解 | 前K个高频单词
  • 工业软件加密锁复制:一场技术与安全的博弈
  • Lovable - AI 驱动的全栈应用开发平台
  • PyTorch张量(Tensor)创建的方式汇总详解和代码示例
  • [笔记] 动态 SQL 查询技术解析:构建灵活高效的企业级数据访问层
  • Linux:1_Linux下基本指令
  • TCP心跳机制详解
  • 使用axios向服务器请求信息并渲染页面
  • 如何在服务器上运行一个github项目
  • K8S的平台核心架构思想[面向抽象编程]
  • docker私有仓库
  • Ai问答之空间站星等
  • 【科研绘图系列】R语言绘制世界地图
  • C++ 中常见的字符串定义方式及其用法
  • 使用Java完成下面项目
  • 解决chrome v2 版本插件不支持
  • uni-app在安卓设备上获取 (WIFI 【和】以太网) ip 和 MAC
  • C语言-数据输入与输出
  • java学习 day4 分布式锁
  • 【Learning Notes】 Derak Callan‘s Business English P38~40
  • 【【异世界历险之数据结构世界(二叉树)】】
  • Why C# and .NET are still relevant in 2025
  • 安装Keycloak并启动服务(macOS)
  • 4.2TCP/IP
  • USB读写自动化压力测试
  • 小波变换 | 离散小波变换
  • AI驱动的软件工程(下):AI辅助的质检与交付
  • FreeRTOS之链表操作相关接口
  • 人工智能如何重构能源系统以应对气候变化?
  • 29.安卓逆向2-frida hook技术-逆向os文件(二)IDA工具下载和使用