当前位置: 首页 > ai >正文

预训练模型适应下游任务?模型参数Freezing 与 微调 !

文章目录

  • 为什么需要模型冻结?
  • 如何在Pytorch中实现模型冻结
  • 思考

Pytorch当中,"模型冻结( Model Freezing)"是指在训练过程中,固定模型中某些层的参数,使其不参与 梯度更新。这意味着这些被冻结的层的权重和偏置将保持不变,及时在反向传播时也不会被优化器修改

为什么需要模型冻结?

模型冻结在深度学习中一个非常重要的技术,尤其在以下场景中非常有用:

1.迁移学习(Transfer Learning

  • 这也是最常见也是最重要的应用。我们通常会使用在大规模数据集(如ImageNet)上预训练好的模型(如ResNet、VGG、BERT等)
  • 这些预训练模型的前几层(或大部分层)已经学习到了非常通用和有用的特征(如,图像的前几层可能学习到边缘、纹理等低级特征)
  • 当我们有一个小规模的、与预训练任务相关但不是完全相同的新任务时,我们可以冻结预训练模型的大部分层,只训练模型的末尾(通常是分类头或任务特定层)的新层

这样做的好处:

  • 加速训练:只更新少量参数,训练速度更快
  • 避免过拟合:对于小数据集,从头开始训练整个深度模型很容易过拟合,而冻结大部分层可以利用预训练知识,减少过拟合风险
  • 节省计算资源:无需计算所有参数的梯度
  • 避免过拟合:对于小数据集,从头开始训练整个深度模型很容易过拟合,而冻结大部分层可以利用预训练知识,减少过拟合风险
  • 节省计算资源:无需计算所有参数的梯度

2.分阶段训练(Staged Training

  • 在某些复杂的模型训练中,可能需要分阶段进行。例如,你可以先训练模型的一部分,使其收敛,然后冻结这部分,再训练模型的另一部分
  • 这有助于稳定训练过程,特别是在处理非常深或复杂的网络

3.特征提取器(Feature Extractor

  • 你可以将一个预训练模型的前几层作一个强大的特征提取器,你冻结这些层,然后用它们的输出训练一个简单的分类器或回归器

如何在Pytorch中实现模型冻结

  • Pytorch中,实现模型冻结的关键在于设置模型的参数requires_grad属性
  • 每个torch.nn.Module中的参数(例如,weightbias)都有一个requires_grad属性,它默认为True时,Pytorch的自动求导机制会为这个参数计算梯度;当它为False时,就不会计算梯度

实现冻结的步骤通常是:

  • 加载预训练模型(如果适用)
  • 遍历模型中的所有参数
    • 对于你想要冻结的层中的参数,将它们的requires_grad设置为False
  • 创建优化器:
    • 关键一步:在创建优化器时,只将requires_gradTrue的参数传递给优化器。如果你将requires_grad=False的参数也传递给优化器,Pytorch可能会报错,因为它期望优化器中的所有参数都有梯度
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models# 1. 加载一个预训练的ResNet模型
model = models.resnet18(pretrained=True)# 2. 冻结大部分层
# 遍历模型中的所有参数
for param in model.parameters():param.requires_grad = False# 3. 修改最后一层(分类头)以适应新任务
# ResNet的最后一层是全连接层 (fc)。
# 假设你的新任务有10个类别
num_ftrs = model.fc.in_features # 获取原始全连接层的输入特征数
model.fc = nn.Linear(num_ftrs, 10) # 替换为新的全连接层# 此时,新替换的 model.fc 层的参数默认 requires_grad=True
# 我们可以确认一下:
# for name, param in model.named_parameters():
#     print(f"Layer: {name}, requires_grad: {param.requires_grad}")# 4. 创建优化器,只优化未被冻结的参数
# 筛选出 requires_grad 为 True 的参数
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001, momentum=0.9)# 现在你可以进行训练了
# 只有 model.fc 层的参数会得到更新,其他层的参数保持不变。# 假设你想要在训练一段时间后,解冻部分或所有层进行微调(Fine-tuning)
# 示例:解冻所有层
# for param in model.parameters():
#     param.requires_grad = True
# # 重新创建优化器,因为参数组可能发生了变化
# optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9) # 学习率通常会调小
  • 注意事项:
    • model.eval()model.train()
      • 冻结参数和设置模型模式 (eval() 或 train()) 是两回事。
      • model.eval() 会将模型设置为评估模式,这会影响一些层(如 nn.BatchNormnn.Dropout)的行为。在 eval 模式下,BatchNorm 层会使用运行平均值和方差,Dropout 层会被禁用。
      • 即使你冻结了所有层,如果你仍然在训练循环中(即你仍然在计算损失和反向传播),你也可能需要使用 model.train(),除非你的目标是只进行推理。
      • 当进行迁移学习并冻结大部分层时,通常会保持被冻结的 BatchNorm 层处于 eval 模式(因为它们在大数据集上学习到的统计量通常更稳定),而 Dropout 层则根据你的训练阶段决定是否启用。

优化器参数组:
当你冻结或解冻参数时,你需要确保你的优化器只包含那些你希望更新的参数。最稳妥的做法是在冻结/解冻操作后重新初始化优化器,并传递正确的参数组。

思考

  • 对于一些特定的任务,当我们使用ResNet这些预训练好的模型,我们可以考虑先冻结全部的参数,如果效果不好,也就是在训练的过程中损失函数最终都不能停留在一个较低的水平,那我们就可以考虑解冻

Freezing的情况,也就是直接冻结,不更新,单纯把预训练模型作为特征提取器

训练数据加载器批次数量: 32
Starting Epoch 1/50
Epoch [1/50] Average Loss: 0.8064
Starting Epoch 2/50
Epoch [2/50] Average Loss: 0.6964
Starting Epoch 3/50
Epoch [3/50] Average Loss: 0.6353
Starting Epoch 4/50
Epoch [4/50] Average Loss: 0.5984
Starting Epoch 5/50
Epoch [5/50] Average Loss: 0.5814
Starting Epoch 6/50
Epoch [6/50] Average Loss: 0.5723
Starting Epoch 7/50
Epoch [7/50] Average Loss: 0.5689
Starting Epoch 8/50
Epoch [8/50] Average Loss: 0.5681
Starting Epoch 9/50
Epoch [9/50] Average Loss: 0.5643
Starting Epoch 10/50
Epoch [10/50] Average Loss: 0.5340
Starting Epoch 11/50
Epoch [11/50] Average Loss: 0.4907
Starting Epoch 12/50
Epoch [12/50] Average Loss: 0.4244
Starting Epoch 11/50
Epoch [11/50] Average Loss: 0.4907
Starting Epoch 12/50
Epoch [12/50] Average Loss: 0.4244
Epoch [11/50] Average Loss: 0.4907
Starting Epoch 12/50
Epoch [12/50] Average Loss: 0.4244
Starting Epoch 12/50
Epoch [12/50] Average Loss: 0.4244
Epoch [12/50] Average Loss: 0.4244

解冻全部参数,直接微调

训练数据加载器批次数量: 19
Starting Epoch 1/50
Starting Epoch 1/50
Epoch [1/50] Average Loss: 0.6509
Starting Epoch 2/50
Epoch [2/50] Average Loss: 0.5947
Starting Epoch 3/50
Epoch [3/50] Average Loss: 0.5958
Starting Epoch 4/50
Epoch [4/50] Average Loss: 0.5426
Starting Epoch 5/50
Epoch [5/50] Average Loss: 0.5101
Starting Epoch 6/50
Epoch [6/50] Average Loss: 0.4670
Starting Epoch 7/50
Epoch [7/50] Average Loss: 0.5027
Starting Epoch 8/50
Epoch [8/50] Average Loss: 0.4306
Starting Epoch 9/50
Epoch [9/50] Average Loss: 0.4411
Starting Epoch 10/50
Epoch [10/50] Average Loss: 0.3459
Starting Epoch 11/50
Epoch [11/50] Average Loss: 0.2989
Starting Epoch 12/50
Epoch [12/50] Average Loss: 0.3167

分析:

  • 同等情况下,在这个直接冻结预训练模型的情况下进行下游任务的时候,损失函数下降没有那么明显,并且,解冻之后的微调,会使得最后的损失函数可以<0.1,可以说,效果不好的时候,可以考虑解冻,更新预训练模型的参数
  • 当然,我们可以根据这个实际的损失函数,调整对应的epoch,因为,不能让这个模型过拟合,考虑一个训练的性价比
Epoch [27/50] Average Loss: 0.0687
Starting Epoch 28/50
Epoch [28/50] Average Loss: 0.0450
Starting Epoch 29/50
Epoch [29/50] Average Loss: 0.0425
Starting Epoch 30/50
Epoch [30/50] Average Loss: 0.0658
Starting Epoch 31/50
Epoch [31/50] Average Loss: 0.0583
Starting Epoch 32/50
Epoch [32/50] Average Loss: 0.0406
Starting Epoch 33/50
Epoch [33/50] Average Loss: 0.0415
Starting Epoch 34/50
Epoch [34/50] Average Loss: 0.0310
Starting Epoch 35/50
Epoch [35/50] Average Loss: 0.0239
Starting Epoch 36/50
Epoch [36/50] Average Loss: 0.0246
Starting Epoch 37/50
Epoch [37/50] Average Loss: 0.0262
Starting Epoch 38/50
Epoch [38/50] Average Loss: 0.0208
Starting Epoch 39/50
Epoch [39/50] Average Loss: 0.0269
Starting Epoch 40/50
Epoch [40/50] Average Loss: 0.0191
Starting Epoch 41/50
Epoch [41/50] Average Loss: 0.0231
Starting Epoch 42/50
Epoch [42/50] Average Loss: 0.0286
Starting Epoch 43/50
Epoch [43/50] Average Loss: 0.0314
Starting Epoch 44/50
Epoch [44/50] Average Loss: 0.0411
Starting Epoch 45/50
Epoch [45/50] Average Loss: 0.0273
Starting Epoch 46/50
Epoch [46/50] Average Loss: 0.0190
Starting Epoch 47/50
Epoch [47/50] Average Loss: 0.0317
Starting Epoch 48/50
Epoch [48/50] Average Loss: 0.0131
Starting Epoch 49/50
Epoch [49/50] Average Loss: 0.0148
Starting Epoch 50/50
Epoch [50/50] Average Loss: 0.0145
http://www.xdnf.cn/news/13693.html

相关文章:

  • 基于Jenkins与Kubernetes的系统化变更管理实践
  • 《前端面试题:call、apply、bind 区别》
  • 1.sql连接语句
  • 软件测试相关问题
  • 柑橘检测模型
  • 直白话 OAuth 2 流程
  • langchain runnables 概念指南
  • 2025年硬件实习/秋招面试准备
  • 小熊派开发板显示图片
  • 机器人导航中的高程图 vs 高度筛选障碍物点云投影 —— 如何高效处理避障问题?
  • Oracle 条件索引 case when 报错解决方案(APP)
  • HTTP 网络协议演进过程
  • 【Docker基础】Docker核心概念:容器(Container)与镜像(Image)的区别与联系
  • Vue3 计算属性 computed
  • 装饰器模式(Decorator Pattern)
  • 【深尚想】M74VHC1GT08DTT1G逻辑芯片安森美ON 工业/物联网首选 电子元器件解析
  • 第29节 Node.js Query Strings
  • Kotlin 中的继承/实现
  • 2025-06-13【api】阿里百炼api调用方法
  • HarmonysOS 模块化设计理念
  • Jsoup解析商品详情时,有哪些常见的标签和属性?
  • 网络安全之CTF专题赛RE题解
  • Python训练营打卡Day49
  • 在QtCreator中使用GitHubCopilot
  • UML和模式应用(软件分析设计与建模期末复习)
  • 华为:eSight网管平台使用snmp纳管交换机
  • 利用Snowflake与SNP Glue揭示数据集成新潜力
  • Ozon欧亚仓网战略解析与中国卖家机遇
  • GUI丝滑教程-python tinker
  • Middleware