打卡Day45
使用PyTorch在CIFAR10数据集上微调ResNet18,并用TensorBoard监控训练过程
1. 环境准备
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
import numpy as np
import os
2. 数据预处理与加载
# Data augmentation and normalization (these are CIFAR-10 per-channel statistics, NOT ImageNet's)
# ---- Data preprocessing & loading ----
# Per-channel mean/std below are the standard CIFAR-10 statistics
# (the original comment incorrectly called them ImageNet statistics).
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training: pad-and-random-crop + horizontal flip augmentation, then normalize.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
# Test: no augmentation, same normalization.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Datasets are downloaded to ./data on first run.
train_set = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=train_transform)
test_set = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=test_transform)

train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size=100, shuffle=False, num_workers=2)
3. 模型准备(ResNet18微调)
# 加载预训练模型并修改
# ---- Model: ResNet-18 pretrained on ImageNet, adapted for CIFAR-10 ----
# NOTE: `pretrained=True` is the legacy torchvision API; newer releases
# prefer `weights=torchvision.models.ResNet18_Weights.DEFAULT`.
model = torchvision.models.resnet18(pretrained=True)

# ResNet-18's stem is designed for 224x224 inputs. For 32x32 CIFAR-10
# images, replace the 7x7 stride-2 conv with a 3x3 stride-1 conv and drop
# the initial maxpool so the tiny input is not downsampled away early.
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity()

# Replace the 1000-way ImageNet head with a 10-way CIFAR-10 classifier.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)

# Use the first GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
4. 训练配置
# ---- Training configuration ----
criterion = nn.CrossEntropyLoss()
# SGD with momentum + weight decay: the standard recipe for CIFAR-10 ResNets.
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
# Decay the learning rate by 10x every 30 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
# TensorBoard writer; run `tensorboard --logdir runs` to monitor training.
writer = SummaryWriter('runs/resnet18_cifar10_finetune')
5. 训练循环(集成TensorBoard日志)
def train(epoch):
    """Run one training epoch and log metrics to TensorBoard.

    Relies on module-level globals: model, train_loader, criterion,
    optimizer, writer, device.

    Args:
        epoch: 0-based epoch index, used as the TensorBoard step.

    Returns:
        (acc, avg_loss): epoch training accuracy in percent and the mean
        per-batch loss.
    """
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        # Batch-level logging every 100 batches. NOTE: the accuracy logged
        # here is the running (cumulative) epoch accuracy, not the accuracy
        # of this single batch.
        if batch_idx % 100 == 0:
            step = epoch * len(train_loader) + batch_idx
            writer.add_scalar('Training/Loss (batch)', loss.item(), step)
            writer.add_scalar('Training/Accuracy (batch)', 100. * correct / total, step)

    # Epoch-level logging.
    avg_loss = train_loss / len(train_loader)
    acc = 100. * correct / total
    writer.add_scalar('Training/Loss (epoch)', avg_loss, epoch)
    writer.add_scalar('Training/Accuracy (epoch)', acc, epoch)
    print(f'Epoch: {epoch} | Train Loss: {avg_loss:.3f} | Acc: {acc:.2f}%')
    return acc, avg_loss


def test(epoch):
    """Evaluate on the test set and log validation metrics to TensorBoard.

    Relies on module-level globals: model, test_loader, criterion,
    optimizer, writer, device.

    Args:
        epoch: 0-based epoch index, used as the TensorBoard step.

    Returns:
        (acc, avg_loss): test accuracy in percent and the mean per-batch loss.
    """
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    avg_loss = test_loss / len(test_loader)
    acc = 100. * correct / total
    writer.add_scalar('Validation/Loss', avg_loss, epoch)
    writer.add_scalar('Validation/Accuracy', acc, epoch)
    # The learning rate is logged here so it is recorded exactly once per epoch.
    writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], epoch)
    print(f'Test Loss: {avg_loss:.3f} | Acc: {acc:.2f}%')
    return acc, avg_loss
# ---- Main training loop ----
# FIX: `best_acc` was read before ever being assigned, which raised a
# NameError on the very first epoch. Initialize it before the loop.
best_acc = 0.0
for epoch in range(100):
    train_acc, train_loss = train(epoch)
    test_acc, test_loss = test(epoch)
    scheduler.step()
    # Checkpoint the model whenever it sets a new best test accuracy.
    if test_acc > best_acc:
        best_acc = test_acc
        torch.save(model.state_dict(), 'best_model.pth')

# Flush and close the TensorBoard writer once training is complete.
writer.close()