当前位置: 首页 > news >正文

Pytorch 版本的lookahead 优化函数使用(附代码)

Lookahead 优化算法是Adam的作者继Adam之后的又一力作,论文可以参见https://arxiv.org/abs/1907.08610

这篇博客先不讲述Lookahead具体原理,先介绍如何将Lookahead集成到现有的代码中。

本人在三个项目中(涉及风格转换、物体识别)使用该优化器,最大的感受就是使用该优化器十分有利于模型收敛,原本不收敛或者收敛过慢的模型在使用lookahead后可以看到明显的收敛情况,并且最终的效果能够满足最初设计的要求。

总所周知,Adam因为其具有较好的适应性,被广泛用于各类模型的优化;其参数简单,调参方便的特点一直为大家所喜爱,尤其对于初学者较为友好。Lookahead 也继承了Adam的优点。lookahead的Pytorch版本代码如下所示:后续会针对代码进行原理讲解,该代码在Github上可以找到。

from collections import defaultdict
from torch.optim import Optimizer
import torchclass Lookahead(Optimizer):def __init__(self, optimizer, k=5, alpha=0.5):self.optimizer = optimizerself.k = kself.alpha = alphaself.param_groups = self.optimizer.param_groupsself.state = defaultdict(dict)self.fast_state = self.optimizer.statefor group in self.param_groups:group["counter"] = 0def update(self, group):for fast in group["params"]:param_state = self.state[fast]if "slow_param" not in param_state:param_state["slow_param"] = torch.zeros_like(fast.data)param_state["slow_param"].copy_(fast.data)slow = param_state["slow_param"]slow += (fast.data - slow) * self.alphafast.data.copy_(slow)def update_lookahead(self):for group in self.param_groups:self.update(group)def step(self, closure=None):loss = self.optimizer.step(closure)for group in self.param_groups:if group["counter"] == 0:self.update(group)group["counter"] += 1if group["counter"] >= self.k:group["counter"] = 0return lossdef state_dict(self):fast_state_dict = self.optimizer.state_dict()slow_state = {(id(k) if isinstance(k, torch.Tensor) else k): vfor k, v in self.state.items()}fast_state = fast_state_dict["state"]param_groups = fast_state_dict["param_groups"]return {"fast_state": fast_state,"slow_state": slow_state,"param_groups": param_groups,}def load_state_dict(self, state_dict):slow_state_dict = {"state": state_dict["slow_state"],"param_groups": state_dict["param_groups"],}fast_state_dict = {"state": state_dict["fast_state"],"param_groups": state_dict["param_groups"],}super(Lookahead, self).load_state_dict(slow_state_dict)self.optimizer.load_state_dict(fast_state_dict)self.fast_state = self.optimizer.statedef add_param_group(self, param_group):param_group["counter"] = 0self.optimizer.add_param_group(param_group)

将lookahead集成在现有代码中如下操作即可:

base_optimizer = Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))
opt = Lookahead(base_optimizer, k=5, alpha=0.5)

此时直接将opt作为正常的优化器使用即可,就像直接使用Adam一样的步骤使用opt

 

http://www.xdnf.cn/news/807913.html

相关文章:

  • WML语言与编程
  • Java 枚举(enum)剖析
  • AI编程助手 Kodezi : 记录、分享一个 VS code 插件
  • 如何符合E-NCAP测试规范?TPT让AEB场景测试更简单:AEB系统的测试场景 | 测试执行与评估 | 测试用例渲染展示
  • 允许Traceroute探测
  • SQL中的distinct的使用方法
  • System.getProperty()方法获取系统变量
  • ubuntu 10.04 下载源列表
  • 《疯狂的站长》读后感1
  • 惊奇的发现37个上班族必看的网站,不看就OUT了
  • iPhone4s降级ios6.1.3流程总结
  • 分享几个普通人做私活赚外快的好地方
  • 大一python编程题库和答案,大一python程序设计考题
  • 值得收藏 Modbus RTU 协议详解
  • 火车头发布html模板,织梦V5.7火车头采集器全套Web发布模块(含软件模型、图集模型、商品模型)...
  • opendirve ,好用的免费直链(外链)网盘
  • 「营业日志 2020.11.26」一道纳什均衡数数题
  • Sky入围CCTV06体坛风云人物侯选名单
  • 素材类dedecms织梦模板免费下载
  • Visual Studio 6.0 企业版 下载
  • DOS操作系统
  • 一文搞定:whois数据库查询域名信息(WHOIS)
  • MSN登陆不了怎么办
  • iPad2 4.3.3完美越狱教程 一键即可操作
  • 在Android中使用SyncAdapter同步数据全攻略
  • HTB靶场 Perfection
  • U盘修复技巧大全
  • 支付宝的架构
  • QQ分享 QQ空间分享 API链接:
  • 黑客帝国之酷炫屏保数字雨