当前位置: 首页 > news >正文

强化学习算法笔记【AMP】

文章目录

  • AMP简介
  • 算法解析
    • 主要参数
    • 计算优势函数
    • 算法更新
  • 代码实现
  • 参考资料

AMP简介

AMP是一种无模型、基于随机政策的政策梯度算法(通过GAIL和PPO的组合进行训练),用于基于物理的动画的反向学习。它使角色能够从大型非结构化数据集中模仿各种行为,而无需运动规划器或其他剪辑选择机制。

算法解析

主要参数

在这里插入图片描述
在这里插入图片描述


计算优势函数

在这里插入图片描述
伪代码中的公式用于计算每个时间步的优势值(Advantage),这是强化学习中用于指导策略更新的关键因素。

反向迭代:伪代码从最后一行(时间步)开始向前计算,这是因为在强化学习中,后续时间步的值对当前时间步的值有直接影响,反向计算可以更高效地利用这些信息。
在这里插入图片描述
在这里插入图片描述


算法更新

在这里插入图片描述
(1)更新参考运动数据集

  • 操作:收集一批大小为 amp_batch_size 的参考运动数据,并将其添加到数据集 M 中。
  • 目的:确保数据集 M 包含最新的参考运动数据,用于后续训练判别器。

(2)计算组合奖励
在这里插入图片描述

  • task_reward_weight:任务奖励的权重。
  • style_reward_weight:风格奖励的权重。
  • discriminator_reward_scale:对风格奖励进行缩放的系数。

(3)计算回报和优势值

在这里插入图片描述
(4)从经验重放区中采样小批量数据

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述
(1)计算新的动作对数概率

在这里插入图片描述

(2)计算熵损失

在这里插入图片描述

(3)计算策略损失

在这里插入图片描述

(4)计算价值损失

在这里插入图片描述

在这里插入图片描述

(1)计算判别器的预测损失

  • logit_AMP:判别器对当前AMP状态S_AMP的预测

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

(2)判别器Logit正则化

在这里插入图片描述
在这里插入图片描述

(3)判别器梯度惩罚

在这里插入图片描述

(4)判别器权重衰减

在这里插入图片描述
(5)step
在这里插入图片描述

(6)学习率更新和更新AMP回放区

在这里插入图片描述

代码实现

# import the agent and its default configuration
from skrl.agents.torch.amp import AMP, AMP_DEFAULT_CONFIG# instantiate the agent's models
models = {}
models["policy"] = ...
models["value"] = ...  # only required during training
models["discriminator"] = ...  # only required during training# adjust some configuration if necessary
cfg_agent = AMP_DEFAULT_CONFIG.copy()
cfg_agent["<KEY>"] = ...# instantiate the agent
# (assuming a defined environment <env> and memory <memory>)
# (assuming defined memories for motion <motion_dataset> and <reply_buffer>)
# (assuming defined methods to collect motion <collect_reference_motions> and <collect_observation>)
agent = AMP(models=models,memory=memory,  # only required during trainingcfg=cfg_agent,observation_space=env.observation_space,action_space=env.action_space,device=env.device,amp_observation_space=env.amp_observation_space,motion_dataset=motion_dataset,reply_buffer=reply_buffer,collect_reference_motions=collect_reference_motions,collect_observation=collect_observation)
AMP_DEFAULT_CONFIG = {"rollouts": 16,                 # number of rollouts before updating"learning_epochs": 6,           # number of learning epochs during each update"mini_batches": 2,              # number of mini batches during each learning epoch"discount_factor": 0.99,        # discount factor (gamma)"lambda": 0.95,                 # TD(lambda) coefficient (lam) for computing returns and advantages"learning_rate": 5e-5,                  # learning rate"learning_rate_scheduler": None,        # learning rate scheduler class (see torch.optim.lr_scheduler)"learning_rate_scheduler_kwargs": {},   # learning rate scheduler's kwargs (e.g. {"step_size": 1e-3})"state_preprocessor": None,             # state preprocessor class (see skrl.resources.preprocessors)"state_preprocessor_kwargs": {},        # state preprocessor's kwargs (e.g. {"size": env.observation_space})"value_preprocessor": None,             # value preprocessor class (see skrl.resources.preprocessors)"value_preprocessor_kwargs": {},        # value preprocessor's kwargs (e.g. {"size": 1})"amp_state_preprocessor": None,         # AMP state preprocessor class (see skrl.resources.preprocessors)"amp_state_preprocessor_kwargs": {},    # AMP state preprocessor's kwargs (e.g. {"size": env.amp_observation_space})"random_timesteps": 0,          # random exploration steps"learning_starts": 0,           # learning starts after this many steps"grad_norm_clip": 0.0,              # clipping coefficient for the norm of the gradients"ratio_clip": 0.2,                  # clipping coefficient for computing the clipped surrogate objective"value_clip": 0.2,                  # clipping coefficient for computing the value loss (if clip_predicted_values is True)"clip_predicted_values": False,     # clip predicted values during value loss computation"entropy_loss_scale": 0.0,          # entropy loss scaling factor"value_loss_scale": 2.5,            # value loss scaling factor"discriminator_loss_scale": 5.0,    # discriminator loss scaling factor"amp_batch_size": 512,                  # batch size for updating the reference motion dataset"task_reward_weight": 0.0,              # task-reward weight (wG)"style_reward_weight": 1.0,             # style-reward weight (wS)"discriminator_batch_size": 0,          # batch size for computing the discriminator loss (all samples if 0)"discriminator_reward_scale": 2,                    # discriminator reward scaling factor"discriminator_logit_regularization_scale": 0.05,   # logit regularization scale factor for the discriminator loss"discriminator_gradient_penalty_scale": 5,          # gradient penalty scaling factor for the discriminator loss"discriminator_weight_decay_scale": 0.0001,         # weight decay scaling factor for the discriminator loss"rewards_shaper": None,         # rewards shaping function: Callable(reward, timestep, timesteps) -> reward"time_limit_bootstrap": False,  # bootstrap at timeout termination (episode truncation)"mixed_precision": False,       # enable automatic mixed precision for higher performance"experiment": {"directory": "",            # experiment's parent directory"experiment_name": "",      # experiment name"write_interval": "auto",   # TensorBoard writing interval (timesteps)"checkpoint_interval": "auto",      # interval for checkpoints (timesteps)"store_separately": False,          # whether to store checkpoints separately"wandb": False,             # whether to use Weights & Biases"wandb_kwargs": {}          # wandb kwargs (see https://docs.wandb.ai/ref/python/init)}
}

参考资料

https://skrl.readthedocs.io/en/latest/api/agents/amp.html

http://www.xdnf.cn/news/104941.html

相关文章:

  • 渗透测试中的信息收集:从入门到精通
  • 心智模式VS系统思考
  • 海外产能达产,威尔高一季度营收利润双双大增
  • 1.5软考系统架构设计师:架构师的角色与能力要求 - 超简记忆要点、知识体系全解、考点深度解析、真题训练附答案及解析
  • 【ROS2】机器人操作系统安装到Ubuntu简介
  • deepseek-php-client开源程序是强力维护的 PHP API 客户端,允许您与 deepseek API 交互
  • 第十五届蓝桥杯 2024 C/C++组 艺术与篮球
  • 【redis】哨兵模式
  • MACD红绿灯副图指标使用技巧,绿灯做多,MACD趋势线,周期共振等实战技术解密
  • 信息系统项目管理工程师备考计算类真题讲解六
  • DeepSeek+Mermaid:轻松实现可视化图表自动化生成(附实战演练)
  • 2025 Java 框架痛点全解析:如何避免性能瓶颈与依赖混乱
  • TI芯片ADS1299的代替品LHE7909其应用领域
  • kali安装切换jdk1.8.0_451java8详细教程
  • Docker配置带证书的远程访问监听
  • 一个关于相对速度的假想的故事-6
  • LeetCode每日一题4.23
  • Codeforces Round 1019 (Div. 2)(ABCD)
  • 【线段树】P1438 无聊的数列|普及+
  • Java Arrays工具类解析(Java 8-17)
  • Spark集群搭建之Yarn模式
  • 将十六进制字符串转换为二进制字符串的方法(Python,C++)
  • Linux内核编译全流程详解与实战指南
  • 汇编语言与二进制分析:从入门到精通的学习路径与实践指南
  • 对流对象的理解
  • 电商行业下的Java核心、Spring生态与AI技术问答
  • MsQuick编译和使用
  • postman 删除注销账号
  • 一种免费的离线ocr-汉字识别率100%
  • 【每日八股】复习 Redis Day2:Redis 的持久化(下)