当前位置: 首页 > news >正文

day52 ResNet18 CBAM

 

在深度学习的旅程中,我们不断探索如何提升模型的性能。今天,我将分享我在 ResNet18 模型中插入 CBAM(Convolutional Block Attention Module)模块,并采用分阶段微调策略的实践过程。通过这个过程,我不仅提升了模型的性能,还对深度学习中的预训练和微调有了更深刻的理解。

一、背景知识

ResNet18 是一种经典的卷积神经网络架构,广泛应用于图像分类任务。CBAM 是一种注意力机制模块,能够同时关注特征图的通道和空间维度,提升模型对关键特征的关注能力。将 CBAM 模块插入 ResNet18 中,可以增强模型的特征表达能力。

二、研究方法

1. CBAM 模块的插入位置

   - CBAM 模块被插入到 ResNet18 的每个残差块(BasicBlock)之后。这样可以在每个特征提取阶段都引入注意力机制,让模型在提取特征的同时学会关注重要的特征。

   - CBAM 模块的初始状态接近“直通”,即在训练初期,CBAM 模块对特征图的影响较小,不会破坏预训练模型的权重。

2. 预训练策略

   - 阶段 1(Epoch 1-5):仅解冻分类头(fc)和所有 CBAM 模块,冻结 ResNet18 的主干卷积层。目标是让模型快速学习新任务的分类边界,同时让 CBAM 模块找到初步的关注点。学习率设置为 1e-3。

   - 阶段 2(Epoch 6-20):解冻高层卷积层(layer3, layer4),保持低层卷积层(layer1, layer2)冻结。目标是让模型的高层特征提取能力适应新任务的抽象概念。学习率设置为 1e-4。

   - 阶段 3(Epoch 21-50):解冻所有层,进行端到端微调。目标是让模型的底层特征也与新任务对齐,提升整体性能。学习率设置为 1e-5。

三、实验过程

1. 数据预处理

   - 使用 CIFAR-10 数据集,包含 10 个类别的 60,000 张 32x32 的彩色图像。

   - 数据增强包括随机裁剪、水平翻转、颜色抖动等。

2. 模型定义

   - 定义了 ResNet18_CBAM 模型,继承自 PyTorch 的 nn.Module。

   - 在每个残差块后插入 CBAM 模块,调整通道数和空间维度的注意力权重。

3. 训练过程

   - 使用 Adam 优化器,动态调整学习率。

   - 每个阶段的训练过程都有详细的日志输出,包括每个 batch 的损失和每个 epoch 的训练准确率和测试准确率。

四、关键结论

1. 训练过程中的损失和准确率变化

   - 在阶段 1,模型的训练准确率从 37.31% 提升到 49.86%,测试准确率从 47.48% 提升到 54.98%。

   - 在阶段 2,模型的训练准确率从 61.34% 提升到 86.26%,测试准确率从 71.71% 提升到 85.99%。

   - 在阶段 3,模型的训练准确率从 88.75% 提升到 95.15%,测试准确率从 87.58% 提升到 90.15%。

2. 最终性能

   - 经过 50 个 epoch 的训练,模型的最终测试准确率达到了 90.15%。这表明 CBAM 模块显著提升了模型的性能,尤其是在高层特征提取和全局微调阶段。

五、代码实现

以下是 ResNet18_CBAM 模型的定义和训练过程的代码实现:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# 定义 CBAM 模块
class ChannelAttention(nn.Module):
    def __init__(self, in_channels, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // ratio, bias=False),
            nn.ReLU(),
            nn.Linear(in_channels // ratio, in_channels, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.shape
        avg_out = self.fc(self.avg_pool(x).view(b, c))
        max_out = self.fc(self.max_pool(x).view(b, c))
        attention = self.sigmoid(avg_out + max_out).view(b, c, 1, 1)
        return x * attention

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        pool_out = torch.cat([avg_out, max_out], dim=1)
        attention = self.conv(pool_out)
        return x * self.sigmoid(attention)

class CBAM(nn.Module):
    def __init__(self, in_channels, ratio=16, kernel_size=7):
        super().__init__()
        self.channel_attn = ChannelAttention(in_channels, ratio)
        self.spatial_attn = SpatialAttention(kernel_size)

    def forward(self, x):
        x = self.channel_attn(x)
        x = self.spatial_attn(x)
        return x

# 定义 ResNet18_CBAM 模型
class ResNet18_CBAM(nn.Module):
    def __init__(self, num_classes=10, pretrained=True, cbam_ratio=16, cbam_kernel=7):
        super().__init__()
        self.backbone = models.resnet18(pretrained=pretrained)
        self.backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.backbone.maxpool = nn.Identity()
        self.cbam_layer1 = CBAM(in_channels=64, ratio=cbam_ratio, kernel_size=cbam_kernel)
        self.cbam_layer2 = CBAM(in_channels=128, ratio=cbam_ratio, kernel_size=cbam_kernel)
        self.cbam_layer3 = CBAM(in_channels=256, ratio=cbam_ratio, kernel_size=cbam_kernel)
        self.cbam_layer4 = CBAM(in_channels=512, ratio=cbam_ratio, kernel_size=cbam_kernel)
        self.backbone.fc = nn.Linear(in_features=512, out_features=num_classes)

    def forward(self, x):
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.layer1(x)
        x = self.cbam_layer1(x)
        x = self.backbone.layer2(x)
        x = self.cbam_layer2(x)
        x = self.backbone.layer3(x)
        x = self.cbam_layer3(x)
        x = self.backbone.layer4(x)
        x = self.cbam_layer4(x)
        x = self.backbone.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.backbone.fc(x)
        return x

# 数据预处理
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# 加载数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=test_transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 训练函数
def train(model, device, train_loader, optimizer, criterion):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()
        if (batch_idx + 1) % 100 == 0:
            print(f'Batch: {batch_idx+1}/{len(train_loader)} | 单Batch损失: {loss.item():.4f} | 累计平均损失: {running_loss/(batch_idx+1):.4f}')
    epoch_loss = running_loss / len(train_loader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

# 测试函数
def test(model, device, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
    epoch_loss = test_loss / len(test_loader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

# 主函数
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")

    model = ResNet18_CBAM(num_classes=10, pretrained=True).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    epochs = 50
    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        train_loss, train_acc = train(model, device, train_loader, optimizer, criterion)
        test_loss, test_acc = test(model, device, test_loader, criterion)
        scheduler.step()
        print(f'Epoch {epoch}/{epochs} 完成 | 耗时: {time.time() - epoch_start_time:.2f}s | 训练准确率: {train_acc:.2f}% | 测试准确率: {test_acc:.2f}%')

    torch.save(model.state_dict(), 'resnet18_cbam_finetuned.pth')
    print("模型已保存为: resnet18_cbam_finetuned.pth")

if __name__ == "__main__":
    main()
@浙大疏锦行

 

http://www.xdnf.cn/news/964513.html

相关文章:

  • Canfestival的移植思想
  • EndNote 21完整安装指南:从零开始的详细步骤(附EndNote下载安装包)
  • HTML 文本省略号
  • HTML 标签 综合案例
  • 在鸿蒙HarmonyOS 5中HarmonyOS应用开发实现QQ音乐风格的播放功能
  • CppCon 2015 学习:Improving the future<T> with monads
  • MinHook 对.NET底层的 SendMessage 拦截真实案例反思
  • PHP和Node.js哪个更爽?
  • 【论文阅读】多任务学习起源类论文《Multi-Task Feature Learning》
  • MyBatis注解开发的劣势与不足
  • LeetCode--27.移除元素
  • Leetcode 3578. Count Partitions With Max-Min Difference at Most K
  • HTML 列表、表格、表单
  • Docker-containerd-CRI-CRI-O-OCI-runc
  • 【kafka】Golang实现分布式Masscan任务调度系统
  • Python 自动化临时邮箱工具,轻松接收验证码,支持调用和交互模式(支持谷歌gmail/googlemail)
  • 【C++】26. 哈希扩展1—— 位图
  • 【PhysUnits】17.5 实现常量除法(div.rs)
  • Linux上并行打包压缩工具
  • Cryosparc: Local Motion Correction注意输出颗粒尺寸
  • 基于大模型的输尿管下段结石诊疗全流程预测与方案研究
  • 多场景 OkHttpClient 管理器 - Android 网络通信解决方案
  • 【AI study】ESMFold安装
  • Ribbon负载均衡实战指南:7种策略选择与生产避坑
  • 深度学习核心概念:优化器、模型可解释性与欠拟合
  • 【无标题新手学习期权从买入看涨期权开始】
  • OpenCV 图像像素值统计
  • Python入门手册:常用的Python标准库
  • C++初阶-list的模拟实现(难度较高)
  • C++学习-入门到精通【17】自定义的模板化数据结构