当前位置: 首页 > backend >正文

Python Day44

Task:
1.预训练的概念
2.常见的分类预训练模型
3.图像预训练模型的发展史
4.预训练的策略
5.预训练代码实战:resnet18


1. 预训练的概念

预训练(Pre-training)是指在大规模数据集上,先训练模型以学习通用的特征表示,然后将其用于特定任务的微调。这种方法可以显著提高模型在目标任务上的性能,减少训练时间和所需数据量。

核心思想:

  • 在大规模、通用的数据(如ImageNet)上训练模型,学习丰富的特征表示。
  • 将预训练模型应用于任务特定的细调(fine-tuning),使模型适应目标任务。

优势:

  • 提升模型性能
  • 缩短训练时间
  • 需要较少的标注数据
  • 提供良好的特征初始化

2. 常见的分类预训练模型

常见的分类预训练模型主要包括:

模型名称提出年份特色与应用
AlexNet2012标志深度学习重返计算机视觉的起点
VGG(VGG16/19)2014简洁结构,深层网络,广泛用于特征提取
ResNet(Residual Network)2015引入残差连接,解决深层网络退化问题
Inception(GoogLeNet)2014多尺度特征提取,复杂模块设计
DenseNet2017密集连接,加深网络而不增加参数
MobileNet2017轻量级模型,适合移动端应用
EfficientNet2019根据模型宽度、深度和分辨率优化设计

这些模型在ImageNet等大规模数据集上预训练,成为计算机视觉各种任务的基础。


3. 图像预训练模型的发展史

  1. AlexNet (2012)
    首次使用深度卷积神经网络大规模应用于ImageNet,显著提升分类效果。

  2. VGG系列 (2014)
    简单堆叠卷积和池化层,深度逐步增加,提高表现。

  3. GoogLeNet/Inception (2014)
    引入Inception模块,进行多尺度特征提取,有效提升效率。

  4. ResNet (2015)
    通过残差连接解决深层网络的退化问题,使网络深度大幅提升(如ResNet-50,ResNet-101等)。

  5. DenseNet (2017)
    特色是密集连接,增强特征传播,改善梯度流。

  6. MobileNet, EfficientNet (2017-2019)
    追求轻量级和高效率,适应移动端和资源有限场景。

总的趋势:

  • 从浅层逐步向深层网络发展
  • 引入残差、密集连接等结构解决深层网络训练难题
  • 注重模型效率与性能平衡

4. 预训练的策略

常用的预训练策略包括:

1. 直接使用预训练模型进行微调(Fine-tuning)

  • 加载预训练权重
  • 替换最后的分类层以适应新任务(如类别数不同)
  • 选择性冻结部分层(如只训练最后几层)或全部训练

2. 特征提取(Feature Extraction)

  • 使用预训练模型的固定特征提取器,从中提取特征
  • 在这些特征基础上训练简单的分类器(如SVM或线性层)

3. 逐层逐步微调(Layer-wise Fine-tuning)

  • 先冻结底层特征层,只训练高层
  • 再逐步解冻低层,进行全层微调

4. 迁移学习(Transfer Learning)

  • 利用预训练模型迁移到相似领域任务中
  • 通过微调适应不同数据分布和任务需求

5. 预训练代码实战:ResNet18

以下是基于PyTorch框架的ResNet18预训练模型加载和微调的示例代码:

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader# 1. 加载预训练ResNet18模型
model = models.resnet18(pretrained=True)# 2. 替换分类层以适应新任务(比如有10个类别)
num_classes = 10
model.fc = nn.Linear(model.fc.in_features, num_classes)# 3. 冻结前面层,只训练最后的全连接层(可选)
for param in model.parameters():param.requires_grad = False  # 冻结所有参数# 只训练最后一层参数
for param in model.fc.parameters():param.requires_grad = True# 4. 定义数据变换
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),
])# 5. 加载数据集
train_dataset = ImageFolder('path_to_train_data', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)val_dataset = ImageFolder('path_to_val_data', transform=transform)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)# 6. 设置优化器(只优化可训练参数)
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
criterion = nn.CrossEntropyLoss()# 7. 训练环节
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)for epoch in range(10):model.train()total_loss = 0for images, labels in train_loader:images = images.to(device)labels = labels.to(device)optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()total_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")# 8. 评估
model.eval()
correct = 0
total = 0
with torch.no_grad():for images, labels in val_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs, 1)correct += (predicted == labels).sum().item()total += labels.size(0)
print(f'Validation Accuracy: {100 * correct / total:.2f}%')

总结

  • 预训练是一种利用大规模数据学习通用特征,从而在目标任务中快速获得优秀表现的技术。
  • 常用的分类预训练模型包括ResNet、VGG、Inception等,发展经历了从浅层到深层、从视觉到效率的不断演变。
  • 预训练策略多样,适应不同场景,微调与特征提取是常用手段。
  • 实战中,可以利用PyTorch提供的模型接口快速加载预训练模型,并进行微调以满足具体需求。
http://www.xdnf.cn/news/12186.html

相关文章:

  • 数据可视化大屏案例落地实战指南:捷码平台7天交付方法论
  • 【达梦数据库】OOM问题排查思路
  • React 新项目
  • OGG-01635 OGG-15149 centos服务器远程抽取AIX oracle11.2.0.4版本
  • Spring框架学习day7--SpringWeb学习(概念与搭建配置)
  • Eureka REST 相关接口
  • 云原生思维重塑数字化基座:从理念到实践的深度剖析
  • Python基于蒙特卡罗方法实现投资组合风险管理的VaR与ES模型项目实战
  • Django CMS 的 Demo
  • 每日算法 -【Swift 算法】三数之和最接近目标值
  • Golang——9、反射和文件操作
  • Redis:介绍和认识,通用命令,数据类型和内部编码,单线程模型
  • 深入浅出玩转物联网时间同步:基于BC260Y的NTP实验与嵌入式仿真教学革命
  • 从《现实不似你所见》探寻与缘起性空的思想交织
  • MySQL间隙锁入手,拿下间隙锁面试与实操
  • [原创](现代Delphi 12指南):[macOS 64bit App开发]: TTask创建多线程, 更简单, 更快捷.
  • 报告精读:“数据银行”概念模型与建设规划研究报告【附全文阅读】
  • JavaSec-SSTI - 模板引擎注入
  • 【ArcGIS应用】ArcGIS‌应用如何进行影像分类?
  • adb 连不上真机设备问题汇总
  • ros2--图像/image
  • halcon c# 自带examples报错 Matching
  • JVM中的各类引用
  • 设计模式域——软件设计模式全集
  • GIC流协议接口
  • android 之 Tombstone
  • 巴科斯-诺尔范式与抽象语法树:CMake语法实例教程
  • 深入学习RabbitMQ队列的知识
  • RabbitMQ实用技巧
  • 18650锂电池组点焊机:高效组装锂电池的关键工具|比斯特自动化