当前位置: 首页 > ops >正文

【pytorch】加载部分模型参数及冻结部分参数

前言

结论:1)部分加载模型参数的关键就是自定义选取训练好的模型的state_dict的键值对,然后更新到需要加载模型参数的新模型的state_dict中。 2)冻结部分参数的关键就是自定义设置需冻结的参数的requires_grad属性值为False,并在优化器中传入参数时,过滤掉requires_grad=False的这部分参数,使其不参与更新。 下文通过实例记录如何在pytorch中只加载部分参数,及冻结部分参数进行训练。
 
背景如下,有一个基于bert的内容分类器和基于bert的序列标注分类器,以下简称classifier与identifier。两个模型在特征抽取部分都采用相同的网络结构,就是bare_bert部分,但两种分类器的类型数量不一样,除此之外,bare_bert后续的网络结构也不同,最终需求为:“先训练classifier,然后将classifier的特征抽取部分bare_bert的参数共享给identifier,然后单独训练identifier。在训练identifier时,其共享的bare_bert的参数保持不变,只更新其层的网络参数。”
 

下面给了classifier与identifier的人造例子,有助于理解,classifier的定义如下:
在这里插入图片描述
classifier各层网络参数值为:
在这里插入图片描述

identifier定义如下:
在这里插入图片描述
identifier各层的参数值为:
在这里插入图片描述

可以发现,“fc1.weight, fc1.bias, fc2.weight, fc2.bias”是classifier与identifier都共有的,但具体值不相同,这里假设classifier是训练完之后的参数值,下面记录如何将共有的参数值共享给identifier。
 

共享部分参数

共享模型参数的核心工具就是模型的state_dict方法与load_state_dict方法。状态字典本质是python中的有序字典。

# *************** 自定义取出需要共享的参数 *******************
from collections import OrderedDict
temp = OrderedDict()ide_state_dict = identifier.state_dict(destination=None)
for name, parameter in classifier.named_parameters():if name in ide_state_dict:temp[name] = parameter# ************** 将共享的参数更新到需训练的模型中 ****************
ide_state_dict.update(temp)  # 更新参数值
identifier.load_state_dict(ide_state_dict)

此时再查看identifier的参数值,可以发现“fc1.weight, fc1.bias, fc2.weight, fc2.bias”部分的参数值已经是classifier的了,注意此时的“requires_grad=True”,下面记录如何冻结部分参数。
在这里插入图片描述

冻结部分参数

模型中的Parameter本质是Tensor的子类,因此其有Tensor的所有属性,其中“requires_grad”属性决定是否需要计算梯度,默认情况下,模型网络中的参数是需要记录梯度的,因此需要将“requires_grad”设置为False,并且在优化器中过滤掉这部分参数。如下所示:

# 自定义冻结部分参数
for name, parameter in identifier.named_parameters():if 'ide_only' not in name:parameter.requries_grad = False# 过滤传入优化器的参数
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, identifier.parameters()))

 

参考资料

pytorch只加载预训练模型中的部分参数及冻结部分参数

Pytorch中,只导入部分模型参数的做法

How the pytorch freeze network in some layers, only the rest of the training?

pytorch冻结部分参数训练另一部分

http://www.xdnf.cn/news/11650.html

相关文章:

  • 经典组合 | PTS + ARMS打造性能和应用诊断利器
  • Unix操作系统的前世今生
  • 盘点世界上最出名的十大黑客(每个都能改变历史的大神人物)
  • 社会工程学工具
  • 2023美赛F题讲解+数据领取
  • 电能表原理
  • 科普:黑客入侵网站究竟是怎么回事?
  • 搜狗浏览器,添加自定义搜索引擎~
  • 【java毕业设计】基于javaEE+原生Servlet+MySql的村镇旅游网站设计与实现(毕业论文+程序源码)——村镇旅游网站
  • MessageBox()简易对话框的用法
  • c216芯片组服务器,几无改变 9系芯片组架构及新功能_Intel主板_主板评测-中关村在线...
  • C语言手搓游戏之经典《推箱子》
  • 【面试重点系列】操作系统常见面试重点题(万字图解)
  • httpUnit介绍及使用示例
  • Wavesplit: End-to-End Speech Separation by Speaker Clustering
  • 静态网页设计html css——HTML+CSS+JavaScript魔域私服游戏HTML(1个页面)
  • 2024年最全kali无线渗透之WEP加密模式与破解13_wep加密过程详解,作为网络安全开发程序员
  • matlab 假设检验
  • linux 命令总结
  • Win_XP_SP3系统下成功安装WinccV6.0_SP3a 经验分享
  • VR全景图片如何制作?揭秘VR全景图片制作全流程
  • 快速上手jQuery:样式操作、效果
  • Cortex简介
  • IPX
  • SCCM安装:(1)准备工作
  • 哈希表-数据结构(C语言)
  • Springboot集成OpenOffice实现各类文件转PDF,在线预览
  • KMP算法详解及各种应用
  • code::blocks代码及信息回顾
  • malloc 的实现原理