当前位置: 首页 > ai >正文

打卡day51

选择数据集

使用 CIFAR-10 数据集：共 10 个类别、60,000 张 32×32 彩色图像（其中 50,000 张用于训练，10,000 张用于测试）。

import torch
from torchvision import datasets, transforms# 数据预处理
# Preprocessing pipeline: turn each image into a tensor, then normalise every
# RGB channel with mean 0.5 / std 0.5, which rescales the values to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 training and test splits, stored under ./data (downloaded if absent).
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Mini-batch loaders; only the training split is shuffled.
batch_size = 128
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

构建基础CNN模型

import torch
import torch.nn as nn


class SimpleCNN(nn.Module):
    """Baseline CNN: two conv/ReLU/max-pool stages followed by two linear layers.

    ``fc1`` expects a 32x8x8 feature map, which is what two 2x2 poolings of a
    3x32x32 input produce, so inputs must be (N, 3, 32, 32). The final layer
    emits 10 logits per sample.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.fc1 = nn.Linear(32 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return logits of shape (N, 10) for a batch ``x`` of shape (N, 3, 32, 32)."""
        x = self.pool(self.relu(self.conv1(x)))  # -> (N, 16, 16, 16)
        x = self.pool(self.relu(self.conv2(x)))  # -> (N, 32, 8, 8)
        # Flatten per sample. view(-1, 32*8*8) would silently turn a wrong
        # spatial size into a wrong batch size; this makes fc1 fail instead.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)

实现CBAM模块

class ChannelAttention(nn.Module):
    """CBAM channel attention: rescale each channel of the input.

    Global average- and max-pooled channel descriptors go through a shared
    two-layer MLP. Their sum, passed through a sigmoid, gives one weight in
    (0, 1) per channel.

    Args:
        in_channels: number of input channels C.
        ratio: reduction ratio of the MLP hidden layer. The hidden width is
            ``C // ratio``, clamped to at least 1 so narrow inputs still work.
    """

    def __init__(self, in_channels, ratio=16):
        super().__init__()
        # Without the clamp, in_channels < ratio would give a 0-unit bottleneck.
        hidden = max(1, in_channels // ratio)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hidden),
            nn.ReLU(),
            nn.Linear(hidden, in_channels),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` (N, C, H, W) multiplied by per-channel weights."""
        avg_out = self.mlp(torch.flatten(self.avg_pool(x), 1))  # (N, C)
        max_out = self.mlp(torch.flatten(self.max_pool(x), 1))  # (N, C)
        channel_weights = self.sigmoid(avg_out + max_out)[:, :, None, None]
        return x * channel_weights


class SpatialAttention(nn.Module):
    """CBAM spatial attention: rescale each spatial position of the input.

    The channel-wise mean and max maps are stacked (2 channels) and
    convolved to one map. Its sigmoid gives a weight in (0, 1) per
    (h, w) location.

    Args:
        kernel_size: convolution size. It must be odd so that
            ``padding=kernel_size // 2`` keeps H and W unchanged.

    Raises:
        ValueError: if ``kernel_size`` is even.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        if kernel_size % 2 == 0:
            # An even kernel would yield an (H+1, W+1) map that cannot
            # broadcast against x; fail here with a clear message instead.
            raise ValueError(
                f"kernel_size must be odd to preserve spatial size, got {kernel_size}"
            )
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` (N, C, H, W) multiplied by a per-position weight map."""
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        spatial = torch.cat([avg_out, max_out], dim=1)  # (N, 2, H, W)
        spatial_weights = self.sigmoid(self.conv(spatial))  # (N, 1, H, W)
        return x * spatial_weights


class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel attention, then spatial.

    Args:
        in_channels: number of input channels.
        ratio: channel-attention MLP reduction ratio (default 16).
        kernel_size: spatial-attention convolution size (default 7, must be odd).
    """

    def __init__(self, in_channels, ratio=16, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(in_channels, ratio)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        """Apply channel attention, then spatial attention; shape is preserved."""
        return self.sa(self.ca(x))

将CBAM集成到CNN中

class CNNWithCBAM(nn.Module):
    """SimpleCNN with a CBAM block inserted after each conv + ReLU stage.

    The layer sizes match SimpleCNN: inputs are (N, 3, 32, 32) and the output
    is (N, 10) logits. CBAM preserves the feature-map shape, so ``fc1`` still
    sees 32x8x8 features after two 2x2 poolings.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.cbam1 = CBAM(16)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.cbam2 = CBAM(32)
        self.fc1 = nn.Linear(32 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return logits of shape (N, 10) for a batch ``x`` of shape (N, 3, 32, 32)."""
        x = self.relu(self.conv1(x))
        x = self.pool(self.cbam1(x))  # attention before pooling -> (N, 16, 16, 16)
        x = self.relu(self.conv2(x))
        x = self.pool(self.cbam2(x))  # -> (N, 32, 8, 8)
        # Flatten per sample, so a wrong input size fails in fc1 rather than
        # silently changing the batch dimension as view(-1, ...) would.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)

训练与比较

设备设置(GPU/CPU)

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm  # progress bar for the training loop

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# train_model reads this module-level name.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

训练函数

def train_model(model, train_loader, test_loader, epochs=50, lr=1e-3):
    """Train ``model`` with Adam and cross-entropy, evaluating after each epoch.

    Each epoch ends with an evaluation on ``test_loader``. The ``state_dict``
    is saved to ``"<ClassName>_best.pth"`` in the current directory after the
    first epoch and again whenever test accuracy improves. This guarantees a
    checkpoint exists after at least one epoch.

    Args:
        model: ``nn.Module`` to train; moved to the module-level ``device``.
        train_loader: yields ``(images, labels)`` batches. Its ``.dataset``
            length is used to average the epoch loss.
        test_loader: yields ``(images, labels)`` batches for evaluation.
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        dict with ``'train_loss'`` (mean per-sample loss per epoch) and
        ``'test_acc'`` (test accuracy in percent per epoch).

    NOTE(review): the "best" checkpoint is chosen on the test set, so the
    reported best accuracy is optimistically biased. A separate validation
    split would avoid that.
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    checkpoint_path = f"{model.__class__.__name__}_best.pth"
    best_acc = 0.0
    history = {'train_loss': [], 'test_acc': []}

    for epoch in range(epochs):
        # Training phase.
        model.train()
        running_loss = 0.0
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            # The loss is a batch mean; weight it by batch size so the epoch
            # figure is a true per-sample mean even with a short last batch.
            running_loss += loss.item() * images.size(0)
        epoch_loss = running_loss / len(train_loader.dataset)
        history['train_loss'].append(epoch_loss)

        # Evaluation phase (shared with the standalone evaluation helper).
        epoch_acc = evaluate_model(model, test_loader)
        history['test_acc'].append(epoch_acc)

        if epoch == 0 or epoch_acc > best_acc:
            best_acc = epoch_acc
            torch.save(model.state_dict(), checkpoint_path)
        print(f"Epoch {epoch+1}: Loss={epoch_loss:.4f}, Test Acc={epoch_acc:.2f}%")

    print(f"Best Test Accuracy: {best_acc:.2f}%")
    return history

评估函数

def evaluate_model(model, test_loader, device=None):
    """Return the top-1 accuracy of ``model`` on ``test_loader``, in percent.

    Switches ``model`` to eval mode (and leaves it there) and runs inference
    under ``torch.no_grad()``.

    Args:
        model: classifier; presumably returns (N, num_classes) scores —
            the prediction is the argmax over dim 1.
        test_loader: yields ``(inputs, labels)`` batches.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters), so
            the inputs always land on the same device as the model.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("test_loader yielded no samples; accuracy is undefined")
    return 100 * correct / total

开始训练

# Build both models.
model_baseline = SimpleCNN()
model_cbam = CNNWithCBAM()

# Train the baseline model (SimpleCNN).
print("Training Baseline Model (SimpleCNN)...")
history_baseline = train_model(model_baseline, train_loader, test_loader)

# Train the CBAM-enhanced model (CNNWithCBAM).
print("\nTraining CBAM-enhanced Model (CNNWithCBAM)...")
history_cbam = train_model(model_cbam, train_loader, test_loader)

# Reload the best checkpoints and evaluate once more. The file names follow
# train_model's "<ClassName>_best.pth" convention; map_location loads the
# tensors straight onto the device the models were trained on.
model_baseline.load_state_dict(
    torch.load(f"{type(model_baseline).__name__}_best.pth", map_location=device)
)
model_cbam.load_state_dict(
    torch.load(f"{type(model_cbam).__name__}_best.pth", map_location=device)
)
final_acc_baseline = evaluate_model(model_baseline, test_loader)
final_acc_cbam = evaluate_model(model_cbam, test_loader)

print("\nFinal Results:")
print(f"Baseline Model Test Accuracy: {final_acc_baseline:.2f}%")
print(f"CBAM-enhanced Model Test Accuracy: {final_acc_cbam:.2f}%")

import matplotlib.pyplot as plt

# Side-by-side curves: training loss (left) and test accuracy (right) per epoch.
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(history_baseline['train_loss'], label='Baseline')
plt.plot(history_cbam['train_loss'], label='CBAM')
plt.xlabel('Epoch')
plt.ylabel('Training Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(history_baseline['test_acc'], label='Baseline')
plt.plot(history_cbam['test_acc'], label='CBAM')
plt.xlabel('Epoch')
plt.ylabel('Test Accuracy (%)')
plt.legend()

plt.tight_layout()
plt.show()

输出
（原文此处为输出图片，抓取时已缺失；应为上方代码绘制的两模型训练损失与测试准确率对比曲线。）

http://www.xdnf.cn/news/13520.html

相关文章:

  • CMake安装教程
  • 2025GEO供应商排名深度解析:源易信息构建AI生态优势
  • 新德通:光通信领域的硬核力量,引领高速互联新时代
  • Appium + Node.js 测试全流程
  • 最接近的三数之和
  • Java 基础知识填空题(共 10 题)
  • 6.ref创建对象类型的响应式数据
  • FPGA实现VESA DSC编码功能
  • 【游戏项目】大型项目Git分支策略与开发流程设计构想
  • 无人机智能运行系统技术解析
  • 为进行性核上性麻痹患者定制:饮食健康指南
  • 全球首个体重管理AI大模型“减单”发布,学AI大模型来近屿智能
  • CMake指令: add_sub_directory以及工作流程
  • 速盾:高防CDN可以加速数据库吗?
  • ​​5G通信设备线路板打样:猎板PCB如何攻克高速数据传输技术瓶颈​​
  • bat 批处理查看文件年龄
  • C51 KEIL使用使用问题处理
  • Java异步编程深度解析:从基础到复杂场景的难题拆解
  • K8S中应用无法获取用户真实ip问题排查
  • 数据链抗干扰
  • DNS小结
  • 避免在 iOS 和 Android 的 WebView 中长按出现复制框等默认行为
  • 手机解压 7z 文件全攻略
  • 【全志V821_FoxPi】2-2 切换为spi nand方案启动
  • HTML5 浮动
  • 统计可分解整数的数量
  • leetcode1584. 连接所有点的最小费用-medium
  • 2025低空经济区的安全与应急控制专题研讨会(SECOLZ 2025)
  • DDoS攻防实战:从应急脚本到AI云防护系统
  • 2025年智慧城市与管理工程国际会议(ICSCME 2025)