当前位置: 首页 > java >正文

详解迁移学习,模型参数冻结,优化器参数定义

目录

  • 前言
  • 冻结参数的原理
    • 模型参数全部冻结
    • 模型参数部分冻结
    • 冻结预训练模型部分层并修改最后一层
  • 优化器
    • 优化器的作用
    • 直接传入参数生成器或参数列表
    • 传入参数组字典列表

前言

在使用预训练模型的时候能因为某些原因无需对模型的参数进行全部更新,这时候我们可能希望将其部分参数或者全部参数进行冻结。这里简单总结一些冻结模型参数的方法。

冻结参数的原理

requires_grad 属性的作用
在 PyTorch 中,每个 torch.Tensor 都有一个 requires_grad 属性,这个属性决定了该张量在反向传播过程中是否需要计算梯度。当 requires_gradTrue 时,PyTorch 会跟踪该张量上的所有操作,并在调用 backward() 方法时计算其梯度;当 requires_gradFalse 时,PyTorch 不会跟踪该张量的操作,也就不会计算其梯度。
在神经网络中,模型的参数本质上就是 torch.Tensor 对象,通过设置参数的 requires_grad 属性,我们可以控制哪些参数参与训练(即计算梯度并更新),哪些参数不参与训练(即冻结)。

模型参数全部冻结

import torch
import torch.nn as nnmodel = nn.Sequential(nn.Linear(10, 20),nn.ReLU(),nn.Linear(20, 1)
)for param in model.parameters():param.requires_grad = False
  • 原理:model.parameters() 是一个生成器,返回的是一个生成器对象,它会迭代返回模型中所有需要计算梯度的参数(即 requires_grad 为 True 的参数)。通过遍历这个生成器,将每个参数的 requires_grad 属性设置为 False,就可以实现整体冻结模型参数的目的。

模型参数部分冻结

import torch
import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3)self.fc1 = nn.Linear(16 * 10 * 10, 10)def forward(self, x):x = self.conv1(x)x = x.view(-1, 16 * 10 * 10)x = self.fc1(x)return xmodel = MyModel()
for name, param in model.named_parameters():if 'conv' in name:param.requires_grad = False
  • 原理:model.named_parameters() 也是一个生成器,返回的也是一个生成器对象,每次迭代返回一个元组 (name, param),其中 name 是参数的名称(字符串类型),param 是对应的参数张量(torch.Tensor 类型)。通过检查参数名称中是否包含特定的字符串(如 ‘conv’),我们可以筛选出需要冻结的参数,并将其 requires_grad 属性设置为 False

冻结预训练模型部分层并修改最后一层

from torchvision.models import resnet18
import torch.nn as nnmodel = resnet18(pretrained=True)for param in model.parameters():param.requires_grad = Falsenum_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 1)
  • 原理:首先,通过遍历 model.parameters() 并将所有参数的 requires_grad 属性设置为 False,冻结了整个预训练模型的参数。然后,我们获取原最后一层全连接层的输入特征数 num_ftrs,并创建一个新的全连接层 nn.Linear(num_ftrs, 1) 替换原最后一层。新创建的层的参数的 requires_grad 属性默认是 True,因此只有这一层的参数会在训练时计算梯度并更新。

优化器

上面介绍了怎么冻结模型参数,冻结参数之后需要对模型的参数进行优化,常见的有两种参数方式,下面也一并介绍下。

优化器的作用

优化器的作用是根据计算得到的梯度来更新模型的参数,以最小化损失函数。在 PyTorch 中,常见的优化器有 SGD、Adam
等,它们都需要传入需要更新的参数。

直接传入参数生成器或参数列表

import torch
import torch.nn as nn
import torch.optim as optimmodel = nn.Sequential(nn.Linear(10, 20),nn.ReLU(),nn.Linear(20, 1)
)optimizer = optim.SGD(model[-1].parameters(), lr=0.001, momentum=0.9)
  • 原理:当我们直接将参数生成器(如 model[-1].parameters())或参数列表传递给优化器时,优化器会将这些参数视为一组,并使用相同的优化超参数(如学习率、动量等)来更新这些参数。在这个例子中,我们只传递了模型最后一层的参数,因此优化器只会更新最后一层的参数。
  • 注意:model[-1].parameters() 返回一个生成器对象,它会迭代返回模型最后一层的所有参数。

传入参数组字典列表

import torch
import torch.nn as nn
import torch.optim as optimmodel = nn.Sequential(nn.Linear(10, 20),nn.ReLU(),nn.Linear(20, 1)
)first_layer_params = list(model[0].parameters())
last_layer_params = list(model[2].parameters())param_list = [{'params': first_layer_params, 'lr': 0.0001},{'params': last_layer_params, 'lr': 0.001}
]optimizer = optim.Adam(param_list)
  • 原理:当我们需要为不同的参数组设置不同的优化超参数时,可以传入一个包含参数组字典的列表。每个字典代表一个参数组,其中 ‘params’ 键对应的值是该参数组的参数列表或参数生成器,还可以包含其他键值对来指定该参数组的特定超参数(如学习率、权重衰减等)。优化器会根据这些参数组的设置,对不同的参数组使用不同的超参数进行优化。
http://www.xdnf.cn/news/3833.html

相关文章:

  • 传感器数据处理笔记
  • Linux中的粘滞位和开发工具和文本编辑器vim
  • 马小帅面试遇“灵魂拷问“
  • hot100:链表倒数k个节点- 力扣(LeetCode)
  • 研0大模型学习(第11天)
  • FFT实现(Cooley-Tukey算法)
  • WEB 前端学 JAVA(二)Java 的发展与技术图谱简介
  • TS 字面量类型
  • Mybatis学习(下)
  • LabVIEW开发风量智能监测系统
  • 【杂谈】-探索 NVIDIA Dynamo 的高性能架构
  • 牛客周赛90 C题- Tk的构造数组 题解
  • STM32智能垃圾桶:四种控制模式实战开发
  • 58认知干货:创业经验分享及企业形式的汇总
  • 【AI面试准备】逻辑思维、严谨性、总结能力、沟通协作、适应力与目标导向
  • 文件一键解密软件工具(支持pdf、word、excel、ppt、rar、zip格式文件)
  • 链接文件及功能安全:英飞凌官方文档摘录 - Tasking链接文件
  • 开上“Python跑的车”——自动驾驶数据可视化的落地之道
  • 使用python写多文件#inlcude
  • Spring AI Advisors API:AI交互的灵活增强利器
  • ES6入门---第三单元 模块三:async、await
  • 网络:TCP三次握手、四次挥手
  • 介词:连接名词与句子其他成分的桥梁
  • 互联网大厂Java面试:从基础到实战
  • 【漫话机器学习系列】239.训练错误率(Training Error Rate)
  • vulkanscenegraph显示倾斜模型(6.4)-多线程下的记录与提交
  • Dalvik虚拟机和ART虚拟机
  • ART 下 Dex 加载流程源码分析 和 通用脱壳点
  • 【ArcGIS微课1000例】0145:如何按照自定义形状裁剪数据框?
  • 学习黑客Linux权限