当前位置: 首页 > news >正文

卷积模型的优化--Dropout、批标准化与学习率衰减

codes

文章目录

  • 卷积模型优化技术:Dropout、批标准化与学习率衰减🔥
    • 0. 首先导入必要的库,数据预处理(跟前面的代码实际上基本一样)
    • 1. Dropout抑制过拟合
      • 1.1 Dropout核心原理
      • 1.2 训练与预测模式
    • 2. 批标准化(Batch Normalization)
      • 2.1 PyTorch中的批标准化层
      • 2.2 BN层的作用
    • 3. 学习速率衰减
      • 3.1 学习率衰减原理
      • 3.2 PyTorch实现
      • 3.2 损失与准确率曲线
    • 5. 总结

卷积模型优化技术:Dropout、批标准化与学习率衰减🔥

0. 首先导入必要的库,数据预处理(跟前面的代码实际上基本一样)

# 首先导入要用到的库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import glob # 用于读取文件路径
from torchvision import transforms # 用于数据预处理
from PIL import Image # 用于读取图片
from torch.utils import data # 用于构建数据集imgs=glob.glob(r'D:/my_all_learning/dataset2/dataset2/*.jpg') 
# 上面是读取图片的路径,/*.jpg表示读取所有jpg格式的图片,imgs是一个列表,里面包含了所有图片的路径。species=['cloudy','rain','shine','sunrise'] #4 classes
# 字典推导式获取类别到编号的映射关系
species_to_idx=dict((c,i) for i,c in enumerate(species))
# 字典推导式获取编号到类别的映射关系
idx_to_species=dict((i,c) for i,c in enumerate(species))# 下面提取图片路径列表对应的标签列表
labels=[]
for img in imgs:for i,c in enumerate(species):if c in img: # 判断图片路径是否包含某个种类的名称labels.append(i)# 获取图片路径列表与对应的标签列表后我们就可以着手编写自定义dataset类了
# 首先定义预处理图片的transform
transform=transforms.Compose([transforms.Resize((96,96)),transforms.ToTensor(),transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])# 然后创建Dataset类
class WT_Dataset(data.Dataset):def __init__(self, imgs_path, labels):self.imgs_path = imgs_pathself.labels = labelsdef __len__(self):return len(self.imgs_path)def __getitem__(self, index):img_path = self.imgs_path[index]label = self.labels[index]pil_img = Image.open(img_path)pil_img = pil_img.convert('RGB') # 转换为RGB格式pil_img = transform(pil_img)return pil_img, label# 创建自定义Dataset类
dataset=WT_Dataset(imgs,labels)# 下面划分数据集和测试集
## 统计数据集的数量
train_count=int(0.8*len(dataset))
test_count=len(dataset)-train_count## 划分
train_dataset,test_dataset=data.random_split(dataset,[train_count,test_count])BATCH_SIZE = 16
# 创建DataLoader
train_dl=data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dl=data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

1. Dropout抑制过拟合

1.1 Dropout核心原理

Dropout通过在神经网络训练过程中随机将一部分神经元的输出设置为0,有效降低过拟合风险。这种方法本质上是对网络结构的随机化,使模型在训练时不会过度依赖某些特定神经元,从而提高模型的泛化能力。

💡 关键理解:神经网络在预测时不应过度依赖特定神经元,即使丢失这些神经元也能从其他特征中学习共同模式

# 在模型中添加Dropout层
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(64*10*10, 1024)def forward(self, x):x = x.view(-1, 64*10*10)x = F.dropout(x)  # 防止过拟合🔥x = F.relu(self.fc1(x))x = F.dropout(x)  # 防止过拟合🔥return x

1.2 训练与预测模式

Dropout层在不同模式下的行为完全不同:

模式行为何时使用
model.train()随机将神经元输出置0训练过程
model.eval()直接输出所有神经元结果预测过程
### 编写训练和测试代码(与前面的相同)
def train(dataloader,model,loss_fn,optimizer):size=len(dataloader.dataset) #获取当前数据集样本总数量num_batches=len(dataloader) #获取当前data loader总批次数# train_loss用于累计所有批次的损失之和, correct用于累计预测正确的样本总数train_loss,correct=0,0'''add is below'''model.train() # 设置模型为训练模式,启用dropout等训练时特有的操作for X,y in dataloader:X,y=X.to(device),y.to(device)# 进行预测,并计算第一个批次的损失pred=model(X)loss=loss_fn(pred,y)# 利用反向传播算法,根据损失优化模型参数optimizer.zero_grad() #先将梯度清零loss.backward() # 损失反向传播,计算模型参数梯度optimizer.step() #根据梯度优化参数with torch.no_grad():# correct用于累计预测正确的样本总数correct+=(pred.argmax(1)==y).type(torch.float).sum().item()# train_loss用于累计所有批次的损失之和train_loss+=loss.item()# train_loss 是所有批次的损失之和,所以计算全部样本的平均损失时需要除以总的批次数train_loss/=num_batches# correct 是预测正确的样本总数,若计算整个apoch总体正确率,需要除以样本总数量correct/=sizereturn train_loss,correct### 测试函数
def test(dataloader,model):size=len(dataloader.dataset)num_batches=len(dataloader)'''add is below'''model.eval() # 设置模型为评估模式,禁用dropout等训练时特有的操作test_loss,correct=0,0with torch.no_grad():for X,y in dataloader:X,y=X.to(device),y.to(device)pred=model(X)test_loss+=loss_fn(pred,y).item()correct+=(pred.argmax(1)==y).type(torch.float).sum().item()test_loss/=num_batchescorrect/=sizereturn test_loss,correct

📝 注释:在测试函数中加入model.eval()确保预测时不使用Dropout,输出更加准确

2. 批标准化(Batch Normalization)

2.1 PyTorch中的批标准化层

PyTorch提供三种批标准化层:

层类型适用场景关键参数
nn.BatchNorm1d全连接层/LSTM层后num_features
nn.BatchNorm2d卷积层后num_features
nn.BatchNorm3d视频/3D数据num_features
# 在卷积层后添加BatchNorm层
'''
下面加入BN层
'''
class Net(nn.Module): #nn.Module是所有神经网络模块的基类def __init__(self):super(Net,self).__init__( ) # 调用父类的构造函数,Net是一个卷积神经网络模型,在torch中继承自nn.Module类self.conv1=nn.Conv2d(3,16,3)'''add is below'''self.bn1=nn.BatchNorm2d(16) # 添加BN层self.conv2=nn.Conv2d(16,32,3)self.bn2=nn.BatchNorm2d(32) # 添加BN层self.conv3=nn.Conv2d(32,64,3)self.bn3=nn.BatchNorm2d(64) # 添加BN层self.fc1=nn.Linear(64*10*10,1024)self.fc2=nn.Linear(1024,4) # 4是输出的类别数#  floor((size+2*padding - kernel_size)/stride) + 1def forward(self,x):# 初始尺寸: [batch, 3, 96, 96]x = F.relu(self.conv1(x))  # 卷积1: (96 - 3) + 1 = 94 → [batch, 16, 94, 94]'''add is below'''x=self.bn1(x) # 添加BN层x = F.max_pool2d(x,2)      # 池化1: 94/2 = 47 → [batch, 16, 47, 47]x = F.relu(self.conv2(x))  # 卷积2: (47 - 3) + 1 = 45 → [batch, 32, 45, 45]x=self.bn2(x) # 添加BN层x = F.max_pool2d(x,2)      # 池化2: 45/2 = 22.5 → 取整22 → [batch, 32, 22, 22]x = F.relu(self.conv3(x))  # 卷积3: (22 - 3) + 1 = 20 → [batch, 64, 20, 20]x=self.bn3(x) # 添加BN层x = F.max_pool2d(x,2)      # 池化3: 20/2 = 10 → [batch, 64, 10, 10]x = x.view(-1,64 * 10 * 10)    # 展平: [batch, 64 * 10 * 10] = [batch, 6400]x=F.dropout(x) # 防止过拟合x = F.relu(self.fc1(x))     # 全连接1: [batch, 1024]x=F.dropout(x) # 防止过拟合x = self.fc2(x)             # 全连接2: [batch, 4]return x

2.2 BN层的作用

批标准化通过对输入数据进行标准化处理,加速模型收敛并提高泛化能力,通常添加在卷积层或全连接层之后。

3. 学习速率衰减

3.1 学习率衰减原理

学习率衰减在训练过程中逐渐减小学习率:

  • 🚀 训练初期:较大学习率快速下降
  • 🐢 训练后期:较小学习率避免跳过极值点

3.2 PyTorch实现

使用lr_scheduler.StepLR实现:

from torch.optim import lr_scheduleroptimizer = optim.Adam(model.parameters(), lr=0.0005)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)# 然后在训练循环中,我们添加一行代码:exp_lr_scheduler.step(),用于更新学习率。
# 这行代码通常放在每个epoch的末尾,以便在每个epoch结束时更新学习率。
"""
下面我们把训练循环的代码封装到一个fit()函数中
"""
def fit(epochs,model,train_dl,test_dl,loss_fn,optimizer,exp_lr_scheduler=None): train_loss=[]train_acc=[]test_loss=[]test_acc=[]for epoch in range(epochs):# 调用train()函数训练epoch_loss,epoch_acc=train(train_dl,model,loss_fn,optimizer)# 调用test()函数测试epoch_test_loss,epoch_test_acc=test(test_dl,model)# 记录训练集和测试集的损失和准确率train_loss.append(epoch_loss)train_acc.append(epoch_acc)test_loss.append(epoch_test_loss)test_acc.append(epoch_test_acc)# 更新学习率if exp_lr_scheduler is not None:exp_lr_scheduler.step()# is not None 是判断变量是否存在的标准写法# 定义一个打印模板template=("epoch:{:2d},train_loss:{:5f},train_acc:{:.1f}%,""test_loss:{:.5f},test_acc:{:.1f}%")# 打印训练集和测试集的损失和准确率print(template.format(epoch+1,epoch_loss,epoch_acc*100,epoch_test_loss,epoch_test_acc*100))print("Done")return train_loss,train_acc,test_loss,test_acc# 然后我们可以调用fit()函数训练模型:
train_loss,train_acc,test_loss,test_acc=fit(30,model,train_dl,test_dl,loss_fn,optimizer,exp_lr_scheduler)

⚠️ 参数说明

  • step_size=10:每10个epoch衰减一次
  • gamma=0.1:学习率衰减为原来的0.1倍

3.2 损失与准确率曲线

# 绘制训练曲线
plt.plot(range(1, epochs+1), train_loss, label='train_loss')
plt.plot(range(1, epochs+1), test_loss, label='test_loss')
plt.legend()
plt.show()

5. 总结

通过本实验,我们验证了三种关键优化技术的效果:

  1. Dropout:有效抑制过拟合,使测试准确率提升至93.8%
  2. 批标准化:加速模型收敛,提高训练稳定性
  3. 学习率衰减:避免训练后期跳过极值点,提升最终性能

💎 最佳实践:在卷积层后添加批标准化,全连接层后添加Dropout,配合学习率衰减策略,可获得最佳模型性能

这些优化技术在实际计算机视觉任务中具有广泛应用价值,能显著提升模型泛化能力和训练效率!🚀

http://www.xdnf.cn/news/1124551.html

相关文章:

  • 每天一个前端小知识 Day 31 - 前端国际化(i18n)与本地化(l10n)实战方案
  • 分支战略论:Git版本森林中的生存法则
  • PHP password_get_info() 函数
  • 时序预测 | Pytorch实现CNN-LSTM-KAN电力负荷时间序列预测模型
  • 深入理解MyBatis延迟加载:原理、配置与实战优化
  • 设备发出、接收数据帧的工作机制
  • B站自动回复工具(破解)
  • Linux连接跟踪Conntrack:原理、应用与内核实现
  • JAVA进阶--JVM
  • 【Linux网络】:HTTP(应用层协议)
  • rk3588平台USB 3.0 -OAK深度相机适配方法
  • 网络编程(TCP连接)
  • 前端同学,你能不能别再往后端传一个巨大的JSON了?
  • 7.14练习案例总结
  • UE5多人MOBA+GAS 22、创建技能图标UI,实现显示蓝耗,冷却,以及数字显示的倒数计时还有雷达显示的倒数计时
  • C语言:20250714笔记
  • OFDM系统中关于信号同步的STO估计与CFO估计的MATLAB仿真
  • 学习笔记——农作物遥感识别与大范围农作物类别制图的若干关键问题
  • 网络编程(套接字)
  • HTML应用指南:利用GET请求获取河南省胖东来超市门店位置信息
  • win10安装Elasticsearch
  • iOS高级开发工程师面试——RunTime
  • 深度解读virtio:Linux IO虚拟化核心机制
  • 一种用于医学图像分割的使用了多尺寸注意力Transformer的混合模型: HyTransMA
  • 记录自己在将python文件变成可访问库文件是碰到的问题
  • Linux的相关学习
  • JavaScript进阶篇——第一章 作用域与垃圾回收机制
  • 2025 R3CTF
  • 日语学习-日语知识点小记-构建基础-JLPT-N3阶段(4):语法+单词+復習+发音
  • JS基础知识(上)