当前位置: 首页 > news >正文

使用 PyTorch 和 SwanLab 实时可视化模型训练

以下是一个使用 PyTorch 和 SwanLab 实现训练可视化监控的完整示例,以 MNIST 手写数字识别为例:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import swanlab# 初始化 SwanLab 实验 (自动生成仪表盘)
swanlab.init(experiment_name="MNIST_CNN",description="Simple CNN on MNIST with SwanLab monitoring",config={"batch_size": 64,"epochs": 10,"learning_rate": 0.01,"model": "CNN"}
)# 1. 数据准备
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)train_loader = DataLoader(train_dataset, batch_size=swanlab.config.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)# 2. 定义 CNN 模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, 3, 1)self.conv2 = nn.Conv2d(32, 64, 3, 1)self.dropout = nn.Dropout(0.25)self.fc1 = nn.Linear(9216, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.conv1(x)x = nn.functional.relu(x)x = self.conv2(x)x = nn.functional.relu(x)x = nn.functional.max_pool2d(x, 2)x = self.dropout(x)x = torch.flatten(x, 1)x = self.fc1(x)x = nn.functional.relu(x)x = self.dropout(x)x = self.fc2(x)return nn.functional.log_softmax(x, dim=1)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=swanlab.config.learning_rate)# 3. 训练循环
def train(epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = nn.functional.nll_loss(output, target)loss.backward()optimizer.step()# 实时记录每个batch的损失if batch_idx % 100 == 0:swanlab.log({"train_loss": loss.item()}, step=epoch * len(train_loader) + batch_idx)# 打印日志到控制台print(f"Epoch: {epoch} | Batch: {batch_idx}/{len(train_loader)} | Loss: {loss.item():.4f}")# 4. 测试函数
def test(epoch):model.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += nn.functional.nll_loss(output, target, reduction='sum').item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)accuracy = 100. * correct / len(test_loader.dataset)# 记录epoch级别的指标swanlab.log({"test_loss": test_loss,"accuracy": accuracy,"epoch": epoch})print(f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%\n")# 5. 执行训练
for epoch in range(1, swanlab.config.epochs + 1):train(epoch)test(epoch)print("训练完成!请在 https://swanlab.cn 查看可视化结果")

关键说明:

  1. SwanLab 初始化

    swanlab.init() # 创建实验并设置跟踪参数
    
  2. 实时日志记录

    swanlab.log({"train_loss": loss.item()}) # 记录每个batch的损失
    
  3. 指标可视化

    swanlab.log({"accuracy": accuracy, "test_loss": test_loss}) # 记录测试指标
    

使用步骤:

  1. 安装依赖:
pip install torch torchvision swanlab
  1. 运行脚本:
python mnist_example.py
  1. 查看结果:
    • 终端会自动打印监控链接(如:SwanLab Experiment: https://swanlab.cn/[username]/MNIST_CNN/runs/[run_id]
    • 或在 SwanLab 官网 登录查看

仪表盘功能:

  1. 实时监控

    • 训练损失曲线(每100个batch更新)
    • 测试精度/损失曲线(每个epoch更新)
  2. 实验管理

    • 记录所有超参数(batch_size, lr等)
    • 保存实验配置和系统环境
    • 对比多次运行结果
  3. 自动分析

    • 训练过程动态可视化
    • 指标变化趋势分析
    • 性能指标汇总统计

通过这个示例,你可以实时:

  • 监控训练损失下降趋势
  • 观察模型在验证集的性能变化
  • 分析不同超参数对结果的影响
  • 比较多次实验的结果差异

SwanLab 会自动保存所有实验数据,即使训练中断也能恢复可视化结果。

http://www.xdnf.cn/news/995725.html

相关文章:

  • Python使用总结之Linux部署python3环境
  • 【测试开发】数据类型篇-列表推导式和字典推导式
  • Vue3+TypeScript实现责任链模式
  • XML 注入与修复
  • 接口测试不再难:智能体自动生成 Postman 集合
  • Apache 反向代理Unity服务器
  • Golang启用.exe文件无法正常运行
  • NGINX 四层 SSL/TLS 支持ngx_stream_ssl_module
  • vue3集成高德地图绘制轨迹地图
  • 鸿蒙 UI 开发基础语法与组件复用全解析:从装饰器到工程化实践指南
  • uni-app 小程序 Cannot read property ‘addEventListener‘ of undefined, mounted hook
  • 一.干货干货!!!SpringAI入门到实战-小试牛刀
  • 山东大学《Web数据管理》期末复习宝典【万字解析!】
  • Mac 系统 Node.js 安装与版本管理指南
  • 使用Gitlab CI/CD结合docker容器实现自动化部署
  • React 集中状态管理方案
  • CentOS变Ubuntu后后端程序SO库报错,解决方案+原理分析!
  • .NET 中的异步编程模型
  • [电赛]MSPM0G3507学习笔记(二) GPIO:led与按键(流水灯、呼吸灯,短按长按与双击,ui预览)
  • 基于OpenCV和深度学习实现图像风格迁移
  • VR 地震安全演练:“透视” 地震,筑牢企业安全新护盾​
  • 16层混压PCB的精密重构:高频基板局部化的黄金法则
  • 【Go-补充】实现动态数组:深入理解 slice 与自定义实现
  • 机器学习 [白板推导](六)[核方法、指数族分布]
  • 【WebSocket】WebSocket架构重构:从分散管理到统一连接的实战经验
  • 新零售视域下实体与虚拟店融合的技术逻辑与商业模式创新——基于开源AI智能名片与链动2+1模式的S2B2C生态构建
  • C#事件基础模型代码
  • 【技术追踪】MMFusion:用于食管癌淋巴结转移诊断的多模态扩散模型(MICCAI-2024)
  • Linux部署bmc TrueSight 监控agent步骤
  • Java学习笔记之:初识nginx