当前位置: 首页 > ds >正文

Python训练营打卡Day43

DAY 43 复习日

作业:
在 Kaggle 上找到一个图像数据集,用 CNN 网络进行训练,并用 Grad-CAM 做可视化。

@浙大疏锦行


选择 Dogs vs Cats 数据集(Kaggle经典二分类问题)

完整代码实现

1. 设置环境并加载数据

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import cv2
from PIL import Image# 检查GPU可用性
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Preprocessing: resize every image to 224x224 and normalize with the standard
# ImageNet channel mean/std that the pretrained ResNet weights were trained with.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the dataset. ImageFolder takes labels from one sub-directory per class
# (e.g. train/cat/ and train/dog/) and assigns indices by sorted folder name.
# NOTE(review): the raw Kaggle archive ships train/ as a flat folder of
# cat.N.jpg / dog.N.jpg files; those must be sorted into per-class folders
# first or ImageFolder finds no classes — confirm the local layout.
dataset = datasets.ImageFolder('./dogs-vs-cats/train', transform=transform)

# Split 80% train / 20% validation. A fixed seed makes the split (and thus the
# reported validation accuracy) reproducible from run to run.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Data loaders: only the training set is shuffled.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Show the class names inferred from the folder names.
print(f"Classes: {dataset.classes}")

2. 定义CNN模型(使用预训练的ResNet18)

class DogCatCNN(nn.Module):
    """ResNet18 backbone with a small two-layer classification head.

    Args:
        num_classes: Number of output logits (default 2). Logit ``i``
            corresponds to label index ``i`` of the training dataset; with
            ImageFolder that order is the sorted folder names (e.g. cat, dog).
        pretrained: Start from ImageNet-pretrained ResNet18 weights
            (default True, as before).
        freeze_backbone: Freeze every backbone parameter so only the new head
            is trained (default True, as before).
    """

    def __init__(self, num_classes=2, pretrained=True, freeze_backbone=True):
        super().__init__()
        self.resnet = self._build_resnet18(pretrained)

        if freeze_backbone:
            for param in self.resnet.parameters():
                param.requires_grad = False

        # Replace the final fully connected layer for our classification task.
        # It is created after the freeze above, so the head stays trainable.
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(
            nn.Linear(num_features, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    @staticmethod
    def _build_resnet18(pretrained):
        """Create ResNet18, preferring the `weights=` API when available."""
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is None:
            # Older torchvision (< 0.13) only understands the boolean flag.
            return models.resnet18(pretrained=pretrained)
        # `pretrained=` is deprecated since torchvision 0.13; IMAGENET1K_V1 is
        # the checkpoint that `pretrained=True` selected.
        return models.resnet18(
            weights=weights_enum.IMAGENET1K_V1 if pretrained else None
        )

    def forward(self, x):
        """Return raw (un-normalized) class logits for the image batch *x*."""
        return self.resnet(x)


model = DogCatCNN().to(device)

3. 训练函数

def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over *loader*: train if *optimizer* is given, else evaluate.

    Returns:
        ``(loss, accuracy)`` as floats, averaged over every sample in
        ``loader.dataset``.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    total_correct = 0
    # Only the training pass needs an autograd graph.
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            # Assumes criterion averages over the batch (nn.CrossEntropyLoss's
            # default); re-weighting by batch size gives a true per-sample mean
            # even when the last batch is short.
            total_loss += loss.item() * inputs.size(0)
            total_correct += (outputs.argmax(dim=1) == labels).sum().item()
    n_samples = len(loader.dataset)
    return total_loss / n_samples, total_correct / n_samples


def train_model(model, criterion, optimizer, num_epochs=10):
    """Train *model* for *num_epochs*, validating after every epoch.

    Uses the module-level ``train_loader``, ``val_loader`` and ``device``.
    Whenever validation accuracy improves, the weights are written to
    ``best_model.pth``.

    Returns:
        *model* loaded with its best-validation weights. If validation
        accuracy never exceeds 0, the last-epoch weights are kept.
    """
    best_acc = 0.0
    best_state = None
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        # Training phase.
        epoch_loss, epoch_acc = _run_epoch(model, train_loader, criterion, optimizer)
        print(f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

        # Validation phase.
        val_loss, val_acc = _run_epoch(model, val_loader, criterion)
        print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}\n')

        # Save the best model so far, both to disk and in memory.
        if val_acc > best_acc:
            best_acc = val_acc
            best_state = {
                k: v.detach().clone() for k, v in model.state_dict().items()
            }
            torch.save(model.state_dict(), 'best_model.pth')

    print(f'Best val Acc: {best_acc:.4f}')
    # Return the weights that achieved best_acc, not whatever the final epoch
    # happened to produce.
    if best_state is not None:
        model.load_state_dict(best_state)
    return model
# Loss function and optimizer. Adam is handed every parameter; those frozen in
# the backbone receive no gradients and are skipped by Adam.
LEARNING_RATE = 1e-3
NUM_EPOCHS = 10
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Train the model.
model = train_model(model, criterion, optimizer, num_epochs=NUM_EPOCHS)

4. Grad-CAM可视化实现

class GradCAM:
    """Grad-CAM heat-map generator for one target layer of *model*.

    A forward hook is registered on *target_layer* at construction. Call
    :meth:`remove`, or use the instance as a context manager, when done so
    repeated visualizations do not pile up hooks on the model. Only the first
    sample of the input batch is explained.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        # The forward hook also attaches a tensor-level gradient hook to the
        # layer output (see save_activations), so no module backward hook is
        # needed.
        self._handles = [target_layer.register_forward_hook(self.save_activations)]

    def save_activations(self, module, input, output):
        """Forward hook: keep the layer output and hook its gradient."""
        self.activations = output.detach()
        # A tensor hook fires when d(score)/d(output) is computed, even when
        # the layer's own parameters are frozen. No graph exists under
        # torch.no_grad(), and registering a hook there would raise.
        if output.requires_grad:
            output.register_hook(
                lambda grad: self.save_gradients(module, None, (grad,))
            )

    def save_gradients(self, module, grad_input, grad_output):
        """Store the gradient w.r.t. the layer output (backward-hook signature)."""
        self.gradients = grad_output[0].detach()

    def remove(self):
        """Detach every hook this instance registered (safe to call twice)."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove()
        return False

    def forward(self, x):
        return self.model(x)

    def __call__(self, x, class_idx=None):
        """Return the Grad-CAM map for ``x[0]`` as a float ndarray in [0, 1].

        Args:
            x: Input batch whose first sample is explained.
            class_idx: Class to explain; defaults to the top-scoring class of
                the first sample.

        Raises:
            RuntimeError: if the hooks captured nothing, e.g. *target_layer*
                is not used in the model's forward pass.
        """
        self.gradients = None
        self.activations = None
        # With a frozen backbone nothing before the head requires grad, so the
        # class score would have no autograd path back to the target layer.
        # Making (a copy of) the input require grad restores that path without
        # touching the caller's tensor.
        x = x.detach().clone().requires_grad_(True)

        with torch.enable_grad():
            output = self.forward(x)
            if class_idx is None:
                class_idx = output[0].argmax().item()
            self.model.zero_grad()
            # Equivalent to backpropagating a one-hot gradient at [0, class_idx].
            output[0, class_idx].backward()

        if self.gradients is None or self.activations is None:
            raise RuntimeError(
                "Grad-CAM hooks captured nothing; is target_layer part of the "
                "model's forward pass?"
            )

        with torch.no_grad():
            # Channel weights: gradients averaged over the spatial dimensions.
            weights = self.gradients[0].mean(dim=(1, 2))
            cam = (self.activations[0] * weights[:, None, None]).mean(dim=0)
            cam = torch.relu(cam)
            peak = cam.max()
            # Guard 0/0 -> NaN when no location carries positive evidence.
            if peak > 0:
                cam = cam / peak
        return cam.cpu().numpy()


def visualize_gradcam(model, image_path, target_layer):
    """Plot *image_path* next to its Grad-CAM overlay for the predicted class.

    Uses the module-level ``transform``, ``device`` and ``dataset`` (for the
    class names).
    """
    # Load and preprocess the image; the with-block closes the file handle.
    with Image.open(image_path) as src:
        img = src.convert('RGB')
    img_tensor = transform(img).unsqueeze(0).to(device)

    # Predict first (no graph needed), then explain that prediction.
    model.eval()
    with torch.no_grad():
        output = model(img_tensor)
    pred_class = torch.argmax(output, dim=1).item()

    # The context manager removes the hooks again, even if Grad-CAM fails.
    with GradCAM(model, target_layer) as grad_cam:
        heatmap = grad_cam(img_tensor, pred_class)

    # OpenCV works in BGR, so convert the resized RGB image.
    img_np = np.array(img.resize((224, 224)))
    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)

    # Upsample the coarse heat map to image size, colorize it and blend 40/60.
    heatmap = cv2.resize(heatmap, (img_np.shape[1], img_np.shape[0]))
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    superimposed_img = heatmap * 0.4 + img_np * 0.6
    superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)

    # Show the original and the overlay side by side.
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB))
    plt.title(f'Original (Pred: {dataset.classes[pred_class]})')
    plt.axis('off')
    plt.subplot(1, 2, 2)
    plt.imshow(cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB))
    plt.title('Grad-CAM')
    plt.axis('off')
    plt.show()
# Grad-CAM target: the last convolution in ResNet18's final residual block.
final_block = model.resnet.layer4[-1]
target_layer = final_block.conv2

# Visualization example.
# NOTE(review): this file sits directly under train/, while ImageFolder expects
# per-class sub-folders — confirm the path exists. The image may also have
# been part of the training split, so it is not a held-out example.
test_image_path = './dogs-vs-cats/train/dog.100.jpg'
visualize_gradcam(model, test_image_path, target_layer)
http://www.xdnf.cn/news/11155.html

相关文章:

  • [Win32]画刷、矩形、不规则区域和剪裁
  • SD卡被写保护怎么解除?
  • 全网最详细的网络安全(白帽黑客)入门教程,282g资源无偿分享
  • mysql bulk update_91.一次性处理多条数据的方法:bulk_create,update,delete
  • 如何更改pcAnywhere的默认端口(zz)
  • linux c可变参数va_start、va_end、va_arg、va_list
  • 使用计算机的便利,allshare play:只要几个步骤立即使用这个非常便利的功能!...
  • adobe dreamweaver cs5序列号
  • Ubuntu 9.04安装教程(傻瓜版)
  • struts2通配符使用
  • 如何在VirtualBox虚拟机中安装XP系统? 转
  • Android签名机制及PMS中校验流程(雷惊风)
  • Cadence Allegro PCB设计88问解析(七) 之 Allegro位号反标OrCAD
  • Java版文本编辑器
  • socks5 运行几个小时后 端口10808不通了,ss5服务正常
  • Spring整合Quartz框架实现定时任务跑批(Maven完整版)
  • 超级详细的GitLab安装 与使用 【Gitlab添加组、创建用户和项目、权限管理】_gitlab群组
  • 探秘GPT:开启人工智能语言模型的新纪元
  • Yandex 邮箱添加
  • 麒麟短线王至尊版 软件及指标 应用!
  • Linux系统下安装部署Linux管理面板1panel
  • Win11系统提示找不到olecli32.dll文件的解决办法
  • web.xml中context-param的配置作用
  • dll和so文件区别与构成
  • 世界环保创业基金会简介
  • LDAP 目录服务器的现代化应用
  • 动态实现RelativeLayout,LinearLayout布局
  • Lenovo笔记本新版Veriface Pro(人像识别)软件介绍
  • InstallShield 12 制作安装包
  • anaconda安装gdal、Fiona、shapely、pyproj、geopandas