当前位置: 首页 > news >正文

小白的进阶之路系列之十二----人工智能从初步到精通pytorch综合运用的讲解第五部分

在本笔记本中,我们将针对Fashion-MNIST数据集训练LeNet-5的变体。Fashion-MNIST是一组描绘各种服装的图像瓦片,有十个类别标签表明所描绘的服装类型。

# PyTorch model and training necessities
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim# Image datasets and image manipulation
import torchvision
import torchvision.transforms as transforms# Image display
import matplotlib.pyplot as plt
import numpy as np# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter# In case you are using an environment that has TensorFlow installed,
# such as Google Colab, uncomment the following code to avoid
# a bug with saving embeddings to your TensorBoard directory# import tensorflow as tf
# import tensorboard as tb
# tf.io.gfile = tb.compat.tensorflow_stub.io.gfile

在TensorBoard中显示图像

让我们首先将数据集中的样本图像添加到TensorBoard:

# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):if one_channel:img = img.mean(dim=0)img = img / 2 + 0.5     # unnormalizenpimg = img.numpy()if one_channel:plt.imshow(npimg, cmap="Greys")else:plt.imshow(np.transpose(npimg, (1, 2, 0)))if __name__ == '__main__':transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])# Store separate training and validations splits in ./datatraining_set = torchvision.datasets.FashionMNIST('./data',download=True,train=True,transform=transform)validation_set = torchvision.datasets.FashionMNIST('./data',download=True,train=False,transform=transform)training_loader = torch.utils.data.DataLoader(training_set,batch_size=4,shuffle=True,num_workers=2)validation_loader = torch.utils.data.DataLoader(validation_set,batch_size=4,shuffle=False,num_workers=2)# Class labelsclasses = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')# Extract a batch of 4 imagesdataiter = iter(training_loader)images, labels = next(dataiter)# Create a grid from the images and show themimg_grid = torchvision.utils.make_grid(images)matplotlib_imshow(img_grid, one_channel=True)plt.show()

输出为:

在这里插入图片描述

上面,我们使用TorchVision和Matplotlib创建了一个小批量输入数据的视觉网格。下面,我们在SummaryWriter上使用add_image()调用来记录TensorBoard使用的图像,并且我们还调用flush())来确保它立即写入磁盘。

    # Default log_dir argument is "runs" - but it's good to be specific# torch.utils.tensorboard.SummaryWriter is imported abovewriter = SummaryWriter('runs/fashion_mnist_experiment_1')# Write image data to TensorBoard log dirwriter.add_image('Four Fashion-MNIST Images', img_grid)writer.flush()# To view, start TensorBoard on the command line with:#   tensorboard --logdir=runs# ...and open a browser tab to http://localhost:6006/

如果您在命令行启动TensorBoard并在新的浏览器选项卡中打开它(通常在localhost:6006),您应该在IMAGES选项卡下看到图像网格。

绘制标量以可视化训练

TensorBoard对于跟踪您的训练进度和效果非常有用。下面,我们将运行一个训练循环,跟踪一些指标,并保存数据供TensorBoard使用。

让我们定义一个模型来对图像块进行分类,以及一个用于训练的优化器和损失函数:

    class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 4 * 4, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(sel
http://www.xdnf.cn/news/778483.html

相关文章:

  • 网络安全问题及对策研究
  • Java面试八股--08-数据结构和算法篇
  • JavaWeb是什么?总结一下JavaWeb的体系
  • MQTTX连接阿里云的物联网配置
  • Linux 下 ChromeDriver 安装
  • 70道Hive高频题整理(附答案背诵版)
  • Express教程【006】:使用Express写接口
  • “草台班子”的成长路径分析
  • 基于InternLM的情感调节大师FunGPT
  • Cilium动手实验室: 精通之旅---1.Getting Started with Cilium
  • 深度学习学习率调度器指南:PyTorch 四大 scheduler 对决
  • # 将本地UI生成器从VLLM迁移到DeepSeek API的完整指南
  • iOS 应用如何防止源码与资源被轻易还原?多维度混淆策略与实战工具盘点(含 Ipa Guard)
  • 深入浅出:Oracle 数据库 SQL 执行计划查看详解(1)——基础概念与查看方式
  • 蛋白质结构预测软件openfold介绍
  • 【请关注】MySQL 中常见的加锁方式及各类锁常见问题及对应的解决方法
  • macos常见且应该避免被覆盖的系统环境变量(避免用 USERNAME 作为你的自定义变量名)
  • 数据结构:递归:自然数之和
  • MYSQL 高级 SQL 技巧
  • 虚拟线程与消息队列:Spring Boot 3.5 中异步架构的演进与选择
  • 从零打造AI面试系统全栈开发
  • 字节新出的MCP应用DeepSearch,有点意思。
  • 基于大模型的短暂性脑缺血发作(TIA)全流程预测与干预系统技术方案
  • forEach不能用return中断循环,还是会走循环外的逻辑
  • idea不识别lombok---实体类报没有getter方法
  • 【计算机网络】第七章 运输层
  • 阿里云无影云桌面深度测评
  • GLIDE论文阅读笔记与DDPM(Diffusion model)的原理推导
  • 调用.net DLL让CANoe自动识别串口号
  • 【 java 集合知识 第一篇 】