当前位置: 首页 > java >正文

PyTorch数据处理工具箱(可视化工具)

可视化工具

Tensorboard是Google TensorFlow的可视化工具,它可以记录训练数据、评估数据、
网络结构、图像等,并且可以在web上展示,对于观察神经网络训练的过程非常有帮助。
PyTorch可以采用tensorboard_logger、visdom等可视化工具,但这些方法比较复杂或不够
友好。为解决这一问题,人们推出了可用于PyTorch可视化的新的更强大的工具——
tensorboardX。

tensorboardX简介

tensorboardX功能很强大,支持scalar、image、figure、histogram、audio、text、
graph、onnx_graph、embedding、pr_curve and videosummaries等可视化方式。

安装也比较方便,先安装tensorflow(CPU或GPU版),然后安装tensorboardX,在命
令行运行以下命令即可。

pip install tensorboardX

使用tensorboardX的一般步骤如下所示。
1)导入tensorboardX,实例化SummaryWriter类,指明记录日志路径等信息。

from tensorboardX import SummaryWriter
#实例化SummaryWriter,并指明日志存放路径。在当前目录没有logs目录将自动创建。
writer = SummaryWriter(log_dir='logs')
#调用实例
writer.add_xxx()
#关闭writer
writer.close()

【说明】
①如果是Windows环境,log_dir注意路径解析,如:

writer = SummaryWriter(log_dir=r'D:\myboard\test\logs')

②SummaryWriter的格式为:

SummaryWriter(log_dir=None, comment='', **kwargs)
#其中comment在文件命名加上comment后缀

③如果不写log_dir,系统将在当前目录创建一个runs的目录。
2)调用相应的API接口,接口一般格式为:

add_xxx(tag-name, object, iteration-number)
#即add_xxx(标签,记录的对象,迭代次数)

3)启动tensorboard服务:
cd到logs目录所在的同级目录,在命令行输入如下命令,logdir等式右边可以是相对路
径或绝对路径。

tensorboard --logdir=logs --port 6006
#如果是Windows环境,要注意路径解析,如
#tensorboard --logdir=r'D:\myboard\test\logs' --port 6006

4)web展示。
在浏览器输入:

http://服务器IP或名称:6006 #如果是本机,服务器名称可以使用localhost

便可看到logs目录保存的各种图形,图4-4为示例图。

image
鼠标在图形上移动,还可以看到对应位置具体数据。
有关tensorboardX的更多内容,大家可参考其官
网:https://github.com/lanpa/tensorboardX。

用tensorboardX可视化神经网络

4.4.1节我们介绍了tensorboardX的主要内容,为帮助大家更好地理解,本节我们将介
绍几个实例。实例内容涉及如何使用tensorboardX可视化神经网络模型、可视化损失值、
图像等。

(1)导入需要的模块

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from tensorboardX import SummaryWriter

(2)构建神经网络

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
self.bn = nn.BatchNorm2d(20)
def forward(self, x):
x = F.max_pool2d(self.conv1(x), 2)
x = F.relu(x) + F.relu(-x)
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = self.bn(x)
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
x = F.softmax(x, dim=1)
return x

(3)把模型保存为graph

#定义输入
input = torch.rand(32, 1, 28, 28)
#实例化神经网络
model = Net()
#将model保存为graph
with SummaryWriter(log_dir='logs',comment='Net') as w:
w.add_graph(model, (input, ))

完整代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from tensorboardX import SummaryWriterclass Net(nn.Module):def __init__(self):super(Net,self).__init__()self.conv1=nn.Conv2d(1,10,kernel_size=5)self.conv2=nn.Conv2d(10,20,kernel_size=5)self.conv2_drop=nn.Dropout2d()self.fc1=nn.Linear(320,50)self.fc2=nn.Linear(50,10)self.bn=nn.BatchNorm2d(20)def forward(self,x):x=F.max_pool2d(self.conv1(x),2)x=F.relu(x)+F.relu(-x)x=F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)),2))x=self.bn(x)x=x.view(-1,320)x=F.relu(self.fc1(x))x=F.dropout(x,training=self.training)x=self.fc2(x)x=F.softmax(x,dim=1)return x#定义输入
input=torch.rand(32,1,28,28)
#实例化神经网络
model=Net()
#将model保存为graph
with SummaryWriter(log_dir='logs',comment='Net') as w:w.add_graph(model,(input,))

打开浏览器,结果如图4-5所示。

tensorboardx可视化计算图

用tensorboardX可视化损失值

可视化损失值,需要使用add_scalar函数,这里利用一层全连接神经网络,训练一元
二次函数的参数。

dtype = torch.FloatTensor
writer = SummaryWriter(log_dir='logs',comment='Linear')
np.random.seed(100)
x_train = np.linspace(-1, 1, 100).reshape(100,1)
y_train = 3*np.power(x_train, 2) +2+ 0.2*np.random.rand(x_train.size).reshape(100,1)
model = nn.Linear(input_size, output_size)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
for epoch in range(num_epoches):
inputs = torch.from_numpy(x_train).type(dtype)
targets = torch.from_numpy(y_train).type(dtype)
output = model(inputs)
loss = criterion(output, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 保存loss的数据与epoch数值
writer.add_scalar('训练损失值', loss, epoch)

用tensorboardX可视化特征图

利用tensorboardX对特征图进行可视化,不同卷积层的特征图的抽取程度是不一样
的。
x从cifair10数据集获取,具体请参考第6章pytorch-06-02.ipynb。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
import os# 永久解决方案:设置环境变量在程序最开始
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 10, kernel_size=5)self.pool1 = nn.MaxPool2d(2)self.conv2 = nn.Conv2d(10, 20, kernel_size=5)self.pool2 = nn.MaxPool2d(2)self.fc1 = nn.Linear(320, 50)self.fc2 = nn.Linear(50, 10)def forward(self, x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))x = x.view(-1, 320)x = F.relu(self.fc1(x))x = self.fc2(x)return xdef visualize_feature_maps():# 初始化net = Net()x = torch.randn(2, 1, 28, 28)  # 使用更合理的输入范围x = (x - x.min()) / (x.max() - x.min())  # 归一化到[0,1]# TensorBoard设置writer = SummaryWriter(log_dir='logs/feature_maps_clean')# 可视化输入writer.add_images('input', x, dataformats='NCHW')# 注册hook函数def conv_hook(module, inp, out):out = out.detach()[:4]  # 只取前4个样本for i in range(min(10, out.size(1))):  # 每个层最多显示10个通道writer.add_images(f'{module.__class__.__name__}_channel_{i}',out[:, i:i + 1],dataformats='NCHW')hooks = []for name, layer in net.named_modules():if isinstance(layer, nn.Conv2d):hooks.append(layer.register_forward_hook(conv_hook))# 前向传播with torch.no_grad():net.eval()output = net(x)# 清理hookfor hook in hooks:hook.remove()writer.close()print("可视化完成!请运行 tensorboard --logdir=logs 查看")if __name__ == '__main__':visualize_feature_maps()
http://www.xdnf.cn/news/18379.html

相关文章:

  • 大模型0基础开发入门与实践:第11章 进阶:LangChain与外部工具调用
  • Building Systems with the ChatGPT API 使用 ChatGPT API 搭建系统(第四章学习笔记及总结)
  • Eino 框架组件协作指南 - 智能图书馆建设手册
  • RAG学习(四)——使用混合检索进行检索优化
  • 机器学习4
  • 自己动手,在Mac开发机上利用ollama部署一款轻量级的大模型Phi-3:mini
  • Python Excel 通用筛选函数
  • 麒麟系统播放图片 速度比较
  • Python工程师进阶学习道路分析
  • 【Django:基础知识】
  • 数据结构-ArrayList
  • Redis实战-基于Session实现分布式登录
  • PyTorch API 1
  • PyTorch API 5
  • 372. 超级次方
  • IIS访问报错:HTTP 错误 500.19 - Internal Server Error
  • Spring Retry实战指南_让你的应用更具韧性
  • 区块链技术:重塑未来互联网的伟大动力
  • Python Day32 JavaScript 数组与对象核心知识点整理
  • 源码编译部署 LAMP 架构详细步骤说明
  • Java设计模式-命令模式
  • python的校园顺路代送系统
  • Day 40:训练和测试的规范写法
  • Flink实现Exactly-Once语义的完整技术分解
  • 利用无事务方式插入数据库解决并发插入问题(最小主键id思路)
  • idea进阶技能掌握, 自带HTTP测试工具HTTP client使用方法详解,完全可替代PostMan
  • 暖哇科技AI调查智能体上线,引领保险调查风控智能化升级
  • 【数据结构】排序算法全解析:概念与接口
  • RK android14 Setting一级菜单IR遥控器无法聚焦问题解决方法
  • Apache ShenYu和Nacos之间的通信原理