当前位置: 首页 > news >正文

深度学习之第八课迁移学习(残差网络ResNet)

目录

简介

一、迁移学习

1.什么是迁移学习

2. 迁移学习的步骤

二、残差网络ResNet

1.了解ResNet

2.ResNet网络---残差结构

三、代码分析

1. 导入必要的库

2. 模型准备(迁移学习)

3. 数据预处理

4. 自定义数据集类

5. 数据加载器

6. 设备配置

7. 训练函数

8. 测试函数

9. 训练配置和执行

整体流程总结


简介

        经过长久的卷积神经网络的学习、我们学习了如何提高模型的准确率,但是最终我们的准确率还是没达到百分之八十。原因是因为我们本身模型的局限,面对现有很多成熟的模型,它们有很好的效果,都是经过多次训练选取了最佳的参数,那我们能不能去使用哪些大佬的模型呢?

        答案是可以的,这就使用到迁移学习的知识。

深度学习之第五课卷积神经网络 (CNN)如何训练自己的数据集(食物分类)

深度学习之第六课卷积神经网络 (CNN)如何保存和使用最优模型

深度学习之第七课卷积神经网络 (CNN)调整学习率

一、迁移学习

1.什么是迁移学习

        迁移学习是指利用已经训练好的模型,在新的任务上进行微调。迁移学习可以加快模型训练速度,提高模型性能,并且在数据稀缺的情况下也能很好地工作。

2. 迁移学习的步骤

        1、选择预训练的模型和适当的层:通常,我们会选择在大规模图像数据集(如ImageNet)上预训练的模型,如VGG、ResNet等。然后,根据新数据集的特点,选择需要微调的模型层。对于低级特征的任务(如边缘检测),最好使用浅层模型的层,而对于高级特征的任务(如分类),则应选择更深层次的模型。

        2、冻结预训练模型的参数:保持预训练模型的权重不变,只训练新增加的层或者微调一些层,避免因为在数据集中过拟合导致预训练模型过度拟合。

        3、在新数据集上训练新增加的层:在冻结预训练模型的参数情况下,训练新增加的层。这样,可以使新模型适应新的任务,从而获得更高的性能。

        4、微调预训练模型的层:在新层上进行训练后,可以解冻一些已经训练过的层,并且将它们作为微调的目标。这样做可以提高模型在新数据集上的性能。

        5、评估和测试:在训练完成之后,使用测试集对模型进行评估。如果模型的性能仍然不够好,可以尝试调整超参数或者更改微调层。

太多概念,我们直接使用残差网络进行迁移学习。

二、残差网络ResNet

1.了解ResNet

        ResNet 网络是在 2015年 由微软实验室中的何凯明等几位大神提出,斩获当年ImageNet竞赛中分类任务第一名,目标检测第一名。获得COCO数据集中目标检测第一名,图像分割第一名。

传统卷积神经网络存在的问题?

卷积神经网络都是通过卷积层和池化层的叠加组成的。 在实际的试验中发现,随着卷积层和池化层的叠加,学习效果不会逐渐变好,反而出现2个问题:

        1、梯度消失和梯度爆炸 梯度消失:若每一层的误差梯度小于1,反向传播时,网络越深,梯度越趋近于0 梯度爆炸:若每一层的误差梯度大于1,反向传播时,网络越深,梯度越来越大

        2、退化问题

如何解决问题?

为了解决梯度消失或梯度爆炸问题,论文提出通过数据的预处理以及在网络中使用 BN(Batch Normalization)层来解决。 为了解决深层网络中的退化问题,可以人为地让神经网络某些层跳过下一层神经元的连接,隔层相连,弱化每层之间的强联系。这种神经网络被称为 残差网络 (ResNets)。

                                        实线为测试集错误率 虚线为训练集错误率

2.ResNet网络---残差结构

ResNet的经典网络结构有:ResNet-18、ResNet-34、ResNet-50、ResNet-101、ResNet-152几种,其中,ResNet-18和ResNet-34的基本结构相同,属于相对浅层的网络,后面3种的基本结构不同于ResNet-18和ResNet-34,属于更深层的网络。

不论是多少层的ResNet网络,它们都有以下共同点:

  • 网络一共包含5个卷积组,每个卷积组中包含1个或多个基本的卷积计算过程(Conv-> BN->ReLU)
  • 每个卷积组中包含1次下采样操作,使特征图大小减半,下采样通过以下两种方式实现:
    • 最大池化,步长取2,只用于第2个卷积组(Conv2_x)
    • 卷积,步长取2,用于除第2个卷积组之外的4个卷积组
  • 第1个卷积组只包含1次卷积计算操作,5种典型ResNet结构的第1个卷积组完全相同,卷积核均为7x7, 步长为均2
  • 第2-5个卷积组都包含多个相同的残差单元,在很多代码实现上,通常把第2-5个卷积组分别叫做Stage1、Stage2、Stage3、Stage4
  • 首先是第一层卷积使用kernel 7∗7,步长为2,padding为3。之后进行BN,ReLU和maxpool。这些构成了第一部分卷积模块conv1。
  • 然后是四个stage,有些代码中用make_layer()来生成stage,每个stage中有多个模块,每个模块叫做building block,resnet18= [2,2,2,2],就有8个building block。注意到他有两种模块BasicBlockBottleneck。resnet18和resnet34用的是BasicBlock,resnet50及以上用的是Bottleneck。无论BasicBlock还是Bottleneck模块,都用到了残差连接(shortcut connection)方式:

下图以ResNet18为例介绍一下它的网络模型

layer1

        ResNet18 ,使用的是 BasicBlocklayer1,特点是没有进行降采样,卷积层的 stride = 1,不会降采样。在进行 shortcut 连接时,也没有经过 downsample 层。

layer2,layer3,layer4

而 layer2layer3layer4 的结构图如下,每个 layer 包含 2 个 BasicBlock,但是第 1 个 BasicBlock 的第 1 个卷积层的 stride = 2,会进行降采样。在进行 shortcut 连接时,会经过 downsample 层,进行降采样和降维

        residual结构使用了一种shortcut的连接方式,也可理解为捷径。让特征矩阵隔层相加,注意F(X)和X形状要相同,所谓相加是特征矩阵相同位置上的数字进行相加。

        一个残差块有2条路径 F(x)和 x,F(x) 路径拟合残差,可称之为残差路径; 路径为`identity mapping`恒等映射,可称之为`shortcut`。图中的⊕为`element-wise addition`,要求参与运算的F(x)  和 x的尺寸要相同。

其中关键技术 Batch Normalization是对每一个卷积后进行标准化

        Batch Normalization目的:使所有的feature map满足均值为0,方差为1的分布规律

三、代码分析

1. 导入必要的库

import torch
from torch.utils.data import DataLoader,Dataset  # 数据加载相关
from PIL import Image  # 图像处理
from torchvision import transforms  # 数据预处理
import numpy as np
from torch import nn  # 神经网络模块
import torchvision.models as models  # 预训练模型

2. 模型准备(迁移学习)

这部分是迁移学习的重点,

# 加载预训练的ResNet-18模型
resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)# 冻结所有预训练参数(迁移学习常用策略)
for param in resnet_model.parameters():print(param)  # 打印参数(实际应用中可删除)param.requires_grad = False  # 冻结参数,不参与训练# 获取原模型最后一层的输入特征数
in_features = resnet_model.fc.in_features  # ResNet18的fc层输入是512# 替换最后一层全连接层,输出类别数为20(根据实际任务调整)
resnet_model.fc = nn.Linear(in_features, 20)# 收集需要更新的参数(只有新替换的全连接层参数)
params_to_update = []
for param in resnet_model.parameters():if param.requires_grad == True:params_to_update.append(param)

        这里采用了迁移学习策略:冻结预训练模型的大部分参数,只训练最后一层的分类器,这样可以加快训练速度并提高效果。

  • models.resnet18():创建 ResNet-18 网络结构
  • weights=models.ResNet18_Weights.DEFAULT:使用在 ImageNet 数据集上预训练好的权重初始化模型
  • 迁移学习的关键操作:保留预训练模型学到的特征提取能力
  • requires_grad = False:告诉 PyTorch 不需要计算这些参数的梯度
  • 原 ResNet-18 用于 1000 类分类,这里替换为 20 类分类
  • 只训练新替换的全连接层参数,大大减少计算量

3. 数据预处理

data_transforms = {'train': transforms.Compose([  # 训练集的数据增强transforms.Resize([300, 300]),  # 调整大小transforms.RandomRotation(45),  # 随机旋转transforms.CenterCrop(224),  # 中心裁剪transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转transforms.RandomVerticalFlip(p=0.5),  # 随机垂直翻转transforms.ToTensor(),  # 转为Tensor# 归一化,使用ImageNet的均值和标准差transforms.Normalize([0.485, 0.456, 0.486], [0.229, 0.224, 0.225])]),'valid': transforms.Compose([  # 验证集不做数据增强,只做必要处理transforms.Resize([224, 224]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.486], [0.229, 0.224, 0.225])]),
}

4. 自定义数据集类

class food_dataset(Dataset):  # 继承Dataset类def __init__(self, file_path, transform=None):self.file_path = file_pathself.imgs = []  # 存储图像路径self.labels = []  # 存储标签self.transform = transform# 从文件中读取图像路径和标签with open(file_path, 'r') as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path, label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):  # 返回数据集大小return len(self.imgs)def __getitem__(self, idx):  # 获取单个样本image = Image.open(self.imgs[idx])  # 打开图像if self.transform:  # 应用预处理image = self.transform(image)# 处理标签,转为Tensorlabel = self.labels[idx]label = torch.from_numpy(np.array(label, dtype=np.int64))return image, label

5. 数据加载器

# 创建训练集和测试集
train_data = food_dataset(file_path='train.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path='test.txt', transform=data_transforms['train'])  # 注意这里可能应该用'valid'# 创建数据加载器,用于批量加载数据
train_dataloader = DataLoader(train_data, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=True)

6. 设备配置

# 自动选择可用的计算设备(GPU优先)
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")# 将模型移动到选定的设备
model = resnet_model.to(device)

7. 训练函数

def train(dataloader, model, loss_fn, optimizer):model.train()  # 切换到训练模式batch_size_num = 1for X, y in dataloader:X, y = X.to(device), y.to(device)  # 将数据移动到设备# 前向传播pred = model.forward(X)loss = loss_fn(pred, y)  # 计算损失# 反向传播和参数更新optimizer.zero_grad()  # 梯度清零loss.backward()  # 反向传播计算梯度optimizer.step()  # 更新参数# 打印训练信息loss = loss.item()if batch_size_num % 64 == 0:print(f"loss: {loss:>7f} [number: {batch_size_num}]")batch_size_num += 1

8. 测试函数

best_acc = 0  # 记录最佳准确率def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()  # 切换到评估模式test_loss, correct = 0, 0with torch.no_grad():  # 关闭梯度计算,节省内存for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()# 计算正确预测的数量correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 计算平均损失和准确率test_loss /= num_batchescorrect /= sizeprint(f"Test result:\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}")# 保存最佳模型global best_accif correct > best_acc:best_acc = correcttorch.save(model, 'best3.pt')  # 保存整个模型

9. 训练配置和执行

# 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()  # 交叉熵损失,适用于分类任务
optimizer = torch.optim.Adam(params_to_update, lr=0.001)  # Adam优化器# 学习率调度器,每10个epoch学习率乘以0.5
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)# 训练轮次
epochs = 20
acc_s = []
loss_s = []# 开始训练
for t in range(epochs):print(f"Epoch {t+1}\n-----------------------")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)scheduler.step()  # 更新学习率
print("Done!")
print(f"最佳的结果:\n Accuracy: {(100*best_acc):>0.1f}%")

整体流程总结

  1. 加载预训练的 ResNet-18 模型并修改最后一层以适应新任务
  2. 定义数据预处理和增强方法
  3. 创建自定义数据集类来读取图像和标签
  4. 设置训练设备(GPU 或 CPU)
  5. 定义训练和测试函数
  6. 配置优化器、损失函数和学习率调度器
  7. 执行多轮训练,每轮结束后在测试集上评估并保存最佳模型

最后我们都结果可以达到百分之90左右,效果得到很大的提升。

http://www.xdnf.cn/news/1470529.html

相关文章:

  • ChartGPT深度体验:AI图表生成工具如何高效实现数据可视化与图表美化?
  • RequestContextFilter介绍
  • 53.【.NET8 实战--孢子记账--从单体到微服务--转向微服务】--新增功能--集成短信发送功能
  • 《C++变量命名与占位:深入探究》
  • SDRAM详细分析—06 存储单元架构和放大器
  • RPC内核细节(转载)
  • 软件设计模式之单例模式
  • 实战:Android 自定义菊花加载框(带超时自动消失)
  • 微型导轨如何实现智能化控制?
  • 9.5 面向对象-原型和原型链
  • 【Linux】Linux 的 cp -a 命令的作用
  • 2025高教社数学建模国赛B题 - 碳化硅外延层厚度的确定(完整参考论文)
  • Overleaf教程+Latex教程
  • Anaconda下载安装及详细配置的保姆级教程【Windows系统】
  • excel里面店铺这一列的数据结构是2C【uniteasone17】这种,我想只保留前面的2C部分,后面的【uniteasone17】不要
  • MySQL 8.0.36 主从复制完整实验
  • S32K3平台ADC 应用说明
  • 无人机RTK模块技术要点与难点
  • GEO排名优化:迈向个性化与语义化搜索时代的智能策略
  • VMwaer虚拟机安装完Centos后无法联网问题
  • SQL时间过滤神器:DATE_SUB+between实战指南,告别硬编码日期!
  • React 组件基础与事件处理
  • 04 - 【HTML】- 常用标签(下篇)
  • Windows环境下实现GitLab与Gitee仓库代码提交隔离
  • 今天一天三面,明天加油DW!!!
  • Linux文件描述符详解
  • baml:为提示工程注入工程化能力的Rust类型安全AI框架详解
  • 【完整源码+数据集+部署教程】广告牌实例分割系统源码和数据集:改进yolo11-dysample
  • MySQL数据库备份攻略:从Docker到本地部署
  • JAiRouter 0.7.0 发布:一键开启 OpenTelemetry 分布式追踪,链路性能全掌握