python打卡day49@浙大疏锦行
知识点回顾:
- 通道注意力模块复习
- 空间注意力模块
- CBAM的定义
作业:尝试对今天的模型检查参数数目,并用tensorboard查看训练过程
一、通道注意力模块复习 & CBAM实现
import torch
import torch.nn as nnclass CBAM(nn.Module):def __init__(self, channels, reduction=16):super().__init__()# 通道注意力self.channel_attention = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(channels, channels//reduction, 1),nn.ReLU(),nn.Conv2d(channels//reduction, channels, 1),nn.Sigmoid())# 空间注意力self.spatial_attention = nn.Sequential(nn.Conv2d(2, 1, 7, padding=3),nn.Sigmoid())def forward(self, x):# 通道注意力ca = self.channel_attention(x)x = x * ca# 空间注意力sa = torch.cat([torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)], dim=1)sa = self.spatial_attention(sa)return x * sa# 在ResNet中插入CBAM
model = resnet18(pretrained=True)
model.layer1[0].add_module("cbam", CBAM(64))
二、参数统计方法
from torchsummary import summary# 检查模型参数
summary(model.to(Config.DEVICE), (3, 224, 224))
三、TensorBoard监控增强
# 在训练循环中添加
writer.add_scalar('Loss/train', running_loss/100, epoch*len(trainloader)+i)
writer.add_scalar('Accuracy/test', accuracy, epoch)# 启动TensorBoard
# 在命令行中运行:tensorboard --logdir=runs
关键点说明:
1. CBAM模块包含通道和空间注意力分支
2. 使用summary函数可显示参数量
3. TensorBoard记录需保持writer实例的持续使用