当前位置: 首页 > java >正文

PyTorch 在 FSDP 模型中使用 EMA

注:本文方法仅在 PyTorch FSDP1(`FullyShardedDataParallel`)模型上验证过,且状态字典类型(`state_dict_type`)为 `SHARDED_STATE_DICT` 的场景。

使用 FSDP 对模型权重进行切分后,该如何维护 EMA?网上搜了一圈也没找到一个靠谱的办法,干脆自己写了一个,实现代码如下:

import os
from collections import defaultdict
from typing import Dict, List

import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType


class ShardEMAModel:
    """Exponential moving average (EMA) of an FSDP model's rank-local weight shards.

    Each rank keeps an EMA copy of only the tensors it holds locally in the
    model's ``SHARDED_STATE_DICT``. The EMA therefore never materializes the full
    model on any rank. Written against FSDP1 (``FullyShardedDataParallel``).

    Sharded-state-dict entries are tracked as follows:

    * objects exposing ``local_shards()`` (``ShardedTensor``): one EMA tensor
      per local shard;
    * any other ``torch.Tensor`` (presumably unsharded buffers, or DTensor
      entries - confirm for your FSDP config): one EMA tensor for the whole value;
    * non-tensor entries are not tracked.

    Floating-point/complex tensors are averaged. Other dtypes (e.g. integer
    counters) are copied from the model, because scaling them by a fractional
    decay is not representable.

    NOTE(review): ``update`` and both ``save_*`` methods call
    ``fsdp_model.state_dict()``, and the saves use torch.distributed.checkpoint.
    Presumably every rank must call them together - confirm.
    """

    def __init__(self, fsdp_model: FSDP, decay: float = 0.999):
        """Snapshot the model's current local shards as the initial EMA state.

        Args:
            fsdp_model: The FSDP-wrapped model whose weights are tracked.
            decay: EMA decay in ``[0, 1]``. Each ``update`` computes
                ``ema = decay * ema + (1 - decay) * current``.

        Raises:
            TypeError: If ``fsdp_model`` is not an ``FSDP`` instance.
            ValueError: If ``decay`` is outside ``[0, 1]``.
        """
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not isinstance(fsdp_model, FSDP):
            raise TypeError(
                "fsdp_model must be a FullyShardedDataParallel instance, "
                f"got {type(fsdp_model).__name__}"
            )
        if not 0.0 <= decay <= 1.0:
            raise ValueError(f"decay must be in [0, 1], got {decay}")
        self.fsdp_model = fsdp_model
        self.decay = decay
        # state-dict key -> EMA tensors, one per local shard (one for plain tensors).
        self.shard_ema_state: Dict[str, List[torch.Tensor]] = defaultdict(list)
        for key, value in self._get_shard_state().items():
            for local in self._local_tensors(value):
                # detach + clone so the EMA never aliases the model's storage.
                self.shard_ema_state[key].append(local.detach().clone())
        self.num_shard_params = sum(
            t.numel() for tensors in self.shard_ema_state.values() for t in tensors
        )
        print(f"Shard EMA Model has {self.num_shard_params / 1e6:.3f}M params.")

    @staticmethod
    def _local_tensors(value: object) -> List[torch.Tensor]:
        """Return the rank-local tensors of one sharded-state-dict entry (may be empty)."""
        if hasattr(value, "local_shards"):  # ShardedTensor: public accessor, not `_local_shards`
            return [shard.tensor for shard in value.local_shards()]
        if isinstance(value, torch.Tensor):  # plain tensor / DTensor: track as a whole
            return [value]
        return []  # non-tensor entries are not tracked

    def _get_shard_state(self):
        """Return the model's state dict produced under ``SHARDED_STATE_DICT``."""
        with FSDP.state_dict_type(self.fsdp_model, StateDictType.SHARDED_STATE_DICT):
            shard_state = self.fsdp_model.state_dict()
        return shard_state

    @torch.inference_mode()
    def update(self):
        """Blend the model's current local shards into the EMA state in place."""
        for key, value in self._get_shard_state().items():
            current_tensors = self._local_tensors(value)
            if not current_tensors:
                continue
            ema_tensors = self.shard_ema_state[key]
            for idx, current in enumerate(current_tensors):
                ema = ema_tensors[idx]
                if ema.is_floating_point() or ema.is_complex():
                    ema.mul_(self.decay).add_(current, alpha=1 - self.decay)
                else:
                    # Non-float state cannot be averaged; track the latest value.
                    ema.copy_(current)

    def _save(self, state_dict: Dict[str, object], save_dir: str) -> None:
        """Write ``state_dict`` to ``save_dir`` with torch.distributed.checkpoint."""
        os.makedirs(save_dir, exist_ok=True)
        dist_cp.save(
            state_dict=state_dict,
            storage_writer=dist_cp.FileSystemWriter(save_dir),
            planner=DefaultSavePlanner(),
        )

    def save_ema_shard_weights(self, save_dir: str):
        """Save the EMA weights as a sharded checkpoint under ``save_dir``.

        The model's sharded state dict serves as a template. Every tracked local
        tensor is swapped for its EMA counterpart, so the checkpoint has the same
        layout (top-level key ``"model"``) as ``save_shard_weights`` writes. Only
        references inside the returned state dict are rebound; no model tensor
        is written to.
        """
        with FSDP.state_dict_type(self.fsdp_model, StateDictType.SHARDED_STATE_DICT):
            shard_state = self.fsdp_model.state_dict()
            for key, value in list(shard_state.items()):
                if hasattr(value, "local_shards"):
                    ema_tensors = self.shard_ema_state[key]
                    for idx, shard in enumerate(value.local_shards()):
                        shard.tensor = ema_tensors[idx]
                elif isinstance(value, torch.Tensor):
                    shard_state[key] = self.shard_ema_state[key][0]
            self._save({"model": shard_state}, save_dir)

    def save_shard_weights(self, save_dir: str):
        """Save the model's current (non-EMA) weights as a sharded checkpoint under ``save_dir``."""
        with FSDP.state_dict_type(self.fsdp_model, StateDictType.SHARDED_STATE_DICT):
            self._save({"model": self.fsdp_model.state_dict()}, save_dir)

使用示例(`update()` 通常放在每次优化器 step 之后调用):

# create FSDP Model and EMA Model
fsdp_model = FSDP(...)
ema_model = ShardEMAModel(fsdp_model, decay=0.99)
# train fsdp model and optimizer weights
...
# update EMA Model shard weights (e.g. once per optimizer step)
# NOTE(review): update() calls fsdp_model.state_dict(); presumably every rank must call it together - confirm.
ema_model.update()
# save EMA Model shard weights
ema_model.save_ema_shard_weights("save_path")
http://www.xdnf.cn/news/17830.html

相关文章:

  • 考研408《计算机组成原理》复习笔记,第四章(3)——指令集、汇编语言
  • 14、C 语言联合体和枚举知识点总结
  • Linux系统Namespace隔离实战:dd/mkfs/mount/unshare命令组合应用
  • 报数游戏(我将每文更新tips)
  • 2022 年全国硕士研究生招生考试真题笔记
  • 杂记 01
  • elasticsearch基础概念与集群部署
  • Blender模拟结构光3D Scanner(一)外参数匹配
  • ARM芯片架构之CoreSight Channel Interface 介绍
  • 20250813测试开发岗(凉)面
  • Spring Security 前后端分离场景下的会话并发管理
  • 商品分类拖拽排序设计
  • 数据结构:队列(Queue)与循环队列(Circular Queue)
  • 【SpringBoot系列-01】Spring Boot 启动原理深度解析
  • 【OpenGL】LearnOpenGL学习笔记07 - 摄像机
  • 《设计模式之禅》笔记摘录 - 15.观察者模式
  • 分布式与微服务宝典
  • Redis基础命令
  • 电商项目微服务架构拆分实战
  • LangGraph 指南篇-基础控制
  • 2025盛夏AI热浪:八大技术浪潮重构数字未来
  • HTML第三次作业
  • C语言相关简单数据结构:顺序表
  • 【深入浅出STM32(1)】 GPIO 深度解析:引脚特性、工作模式、速度选型及上下拉电阻详解
  • IPC Inter-Process Communication(进程间通信)
  • 桌面运维如何深造
  • 算法篇----分治(归并排序)
  • Unity新手制作跑酷小游戏详细教程攻略
  • Python实现点云概率ICP(GICP)配准——精配准
  • 【金仓数据库产品体验官】_从实践看金仓数据库与 MySQL 的兼容性