当前位置: 首页 > news >正文

打卡Day45

使用 PyTorch 在 CIFAR-10 数据集上微调 ResNet18,并用 TensorBoard 监控训练过程

1. 环境准备

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
import numpy as np
import os

2. 数据预处理与加载

# Data augmentation and normalization.
# The mean/std below are the commonly used per-channel CIFAR-10 statistics,
# not the ImageNet ones ((0.485, 0.456, 0.406) / (0.229, 0.224, 0.225)).
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random crop with 4-pixel padding and horizontal flip
# for augmentation, then tensor conversion and normalization.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Evaluation pipeline: no augmentation, only tensor conversion and
# the same normalization as training.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Load the datasets (downloaded into ./data on first use).
train_set = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=train_transform)
test_set = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=test_transform)

# NOTE(review): num_workers > 0 in module-level code normally needs an
# `if __name__ == "__main__":` guard on platforms that spawn worker
# processes (e.g. Windows) -- confirm the target platform.
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size=100, shuffle=False, num_workers=2)

3. 模型准备(ResNet18微调)

# Load the ImageNet-pretrained ResNet18.
# `pretrained=True` is deprecated since torchvision 0.13 in favour of the
# `weights=` argument; IMAGENET1K_V1 is the checkpoint `pretrained=True`
# used to select. Fall back to the legacy flag on older torchvision.
_resnet18_weights = getattr(torchvision.models, "ResNet18_Weights", None)
if _resnet18_weights is not None:
    model = torchvision.models.resnet18(weights=_resnet18_weights.IMAGENET1K_V1)
else:
    model = torchvision.models.resnet18(pretrained=True)

# Adapt the stem to 32x32 inputs (the stock network targets 224x224):
# swap the first convolution for a 3x3 / stride-1 one and drop the initial
# max-pool so the feature map is not downsampled too early.
# Note: the replacement conv1 is freshly initialised; it does not inherit
# the pretrained weights of the layer it replaces.
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity()

# Replace the final fully connected layer (CIFAR-10 has 10 classes).
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)

# Move the model to the GPU when one is available, otherwise stay on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

4. 训练配置

# Loss function for multi-class classification.
criterion = nn.CrossEntropyLoss()

# SGD with momentum and weight decay over all model parameters.
optimizer = optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=5e-4,
)

# Multiply the learning rate by 0.1 every 30 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=30,
    gamma=0.1,
)

# TensorBoard writer; event files go under runs/resnet18_cifar10_finetune.
writer = SummaryWriter('runs/resnet18_cifar10_finetune')

5. 训练循环(集成TensorBoard日志)

def train(epoch):
    """Run one training epoch and log metrics to TensorBoard.

    Uses the module-level ``model``, ``train_loader``, ``criterion``,
    ``optimizer``, ``device`` and ``writer``.

    Args:
        epoch: Zero-based epoch index, used as the TensorBoard step for
            epoch-level scalars and to derive the global batch step.

    Returns:
        A ``(accuracy, average_loss)`` tuple: accuracy in percent over the
        whole epoch and the loss averaged over batches.
    """
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    steps_per_epoch = len(train_loader)

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        # Log batch-level data every 100 batches. The accuracy scalar is the
        # running accuracy since the start of the epoch, not the accuracy
        # of this single batch.
        if batch_idx % 100 == 0:
            global_step = epoch * steps_per_epoch + batch_idx
            writer.add_scalar('Training/Loss (batch)', loss.item(), global_step)
            writer.add_scalar('Training/Accuracy (batch)', 100. * correct / total, global_step)

    # Log epoch-level data.
    avg_loss = train_loss / steps_per_epoch
    acc = 100. * correct / total
    writer.add_scalar('Training/Loss (epoch)', avg_loss, epoch)
    writer.add_scalar('Training/Accuracy (epoch)', acc, epoch)

    print(f'Epoch: {epoch} | Train Loss: {avg_loss:.3f} | Acc: {acc:.2f}%')
    return acc, avg_loss


def test(epoch):
    """Evaluate the model on the test set and log metrics to TensorBoard.

    Uses the module-level ``model``, ``test_loader``, ``criterion``,
    ``optimizer``, ``device`` and ``writer``. Runs under ``torch.no_grad()``.

    Args:
        epoch: Zero-based epoch index, used as the TensorBoard step.

    Returns:
        A ``(accuracy, average_loss)`` tuple: accuracy in percent over the
        test set and the loss averaged over batches.
    """
    model.eval()
    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Log validation results.
    avg_loss = test_loss / len(test_loader)
    acc = 100. * correct / total
    writer.add_scalar('Validation/Loss', avg_loss, epoch)
    writer.add_scalar('Validation/Accuracy', acc, epoch)

    # Log the learning rate of the first parameter group.
    writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], epoch)

    print(f'Test Loss: {avg_loss:.3f} | Acc: {acc:.2f}%')
    return acc, avg_loss
# Main training loop: train, evaluate, step the LR schedule, and keep the
# checkpoint with the best test accuracy.
# best_acc must exist before the first comparison; without this
# initialisation the loop raises NameError on the first epoch.
best_acc = 0.0
try:
    for epoch in range(100):
        train_acc, train_loss = train(epoch)
        test_acc, test_loss = test(epoch)
        scheduler.step()

        # Save the best model seen so far.
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), 'best_model.pth')
finally:
    # Close the TensorBoard writer even if training is interrupted.
    writer.close()
http://www.xdnf.cn/news/886609.html

相关文章:

  • Redis(02)Win系统如何将Redis配置为开机自启的服务
  • 如何选择专业数据可视化开发工具?为您拆解捷码全功能和落地指南!
  • Android 进程分类
  • 5G 网络中 DRX(非连续接收)技术深度解析
  • java: 找不到符号 符号: 变量 log
  • 【opencv】基础知识到进阶(更新中)
  • Modern C++(三)表达式
  • Kafka深度解析与原理剖析
  • MySQL数据库基础(一)———数据库管理
  • 华为OD最新机试真题-小明减肥-OD统一考试(B卷)
  • python编写赛博朋克风格天气查询程序
  • PyTorch中matmul函数使用详解和示例代码
  • vscode 离线安装第三方库跳转库
  • python3.9带 C++绑定的基础镜像
  • 【深尚想】OPA855QDSGRQ1运算放大器IC德州仪器TI汽车级高速8GHz增益带宽的全面解析
  • 基于ResNet残差网络优化梯度下降算法实现图像分类
  • 编程技能:格式化打印05,格式控制符
  • 人工智能AI在数字化转型有哪些应用?
  • Android设置顶部状态栏透明,以及状态栏字体颜色
  • TDengine 开发指南—— UDF函数
  • 【JeecgBoot AIGC】AI知识库实战应用与搭建
  • 01 Deep learning神经网络的编程基础 二分类--吴恩达
  • Windows应用-GUID工具
  • LFWG2024.08
  • BeeWorks 协同办公能力:局域网内企业级协作的全场景重构
  • 电脑提示dll文件缺失怎么办 dll修复方法
  • 【Elasticsearch】 查询优化方式
  • openvino如何在c++中调用pytorch训练的模型
  • 【Oracle】分区表
  • Maxscript快速入门(四)