当前位置: 首页 > news >正文

第四章、SKRL(2): API(Models and Model instantiators)

0 前言

官方文档:https://skrl.readthedocs.io/en/latest/api/models.html

模型(或智能体模型)是指智能体的策略、值函数等的表示,智能体使用这些策略、值函数等来做出决策。智能体可以有一个或多个模型,它们的参数由优化算法来调整。
在这里插入图片描述

1 继承 Model类

# 导入类型提示相关模块
from typing import Optional, Union, Mapping, Sequence, Tuple, Any
# 导入强化学习环境库(OpenAI Gym的继任者)
import gymnasium
# 导入PyTorch深度学习框架
import torch
# 从skrl库中导入PyTorch模型基类
from skrl.models.torch import Model
# 定义自定义模型类,继承自skrl的Model基类
class CustomModel(Model):def __init__(self,# 观测空间:可以是整数、整数序列或gymnasium.Space对象observation_space: Union[int, Sequence[int], gymnasium.Space],# 动作空间:可以是整数、整数序列或gymnasium.Space对象action_space: Union[int, Sequence[int], gymnasium.Space],# 设备:可选参数,指定模型运行的设备(CPU/GPU)device: Optional[Union[str, torch.device]] = None) -> None:"""自定义强化学习模型:param observation_space: 观测/状态空间或形状。模型的`num_observations`属性将包含该空间的大小:type observation_space: int, sequence of int, gymnasium.Space:param action_space: 动作空间或形状。模型的`num_actions`属性将包含该空间的大小:type action_space: int, sequence of int, gymnasium.Space:param device: 分配或将要分配PyTorch张量的设备(默认:`None`)。如果为None,将自动选择可用设备(优先使用CUDA):type device: str or torch.device, optional"""# 调用父类构造函数初始化基础属性super().__init__(observation_space, action_space, device)# =====================================# - 此处应定义自定义属性和初始化操作#   例如神经网络层、初始化方法等# =====================================def act(self,# 输入字典:包含状态等输入数据inputs: Mapping[str, Union[torch.Tensor, Any]],# 角色标识:用于区分不同功能的模型实例role: str = "") -> Tuple[torch.Tensor, Union[torch.Tensor, None], Mapping[str, Union[torch.Tensor, Any]]]:"""根据输入状态生成动作:param inputs: 模型输入字典,常见键包括:- "states": 用于决策的环境状态- "taken_actions": 已采取的动作(可选):type inputs: dict,值通常为torch.Tensor:param role: 模型的角色标识(例如区分策略模型和值函数模型):type role: str, optional:return: 元组包含:1. 动作张量2. 动作对数概率(随机策略)或None(确定性策略)3. 额外信息字典:rtype: tuple[torch.Tensor, torch.Tensor或None, dict]"""# ==============================# - 此处应实现具体的动作生成逻辑#   示例步骤:#   1. 从inputs中获取状态#   2. 通过网络前向传播#   3. 计算动作及概率# ==============================# 示例伪代码:# states = inputs["states"]# action_logits = self.net(states)# action_dist = torch.distributions.Categorical(logits=action_logits)# actions = action_dist.sample()# return actions, action_dist.log_prob(actions), {}# 当前未实现具体逻辑,需要补充raise NotImplementedError("act方法需要具体实现")

该部分给出继承自Model类各种不同的模型,总结为下述表格

模型分类作用域训练的模型具体算法模型
Tabular model(表格模型)离散EpilonGreedyPolicyQ-learning、SARSA
Categorical model(分类模型)离散MLP、CNN、RNN、GRU、LSTM(输入是状态输出是离散动作概率分布)根据具体情况来看动作空间是连续的还是离散的以及states的输入是什么
Multi-Categorical model(多分类模型)离散MLP、CNN、RNN、GRU、LSTM(输入是状态输出是离散动作概率分布)根据具体情况来看动作空间是连续的还是离散的以及states的输入是什么
Gaussian mode(高斯模型)连续MLP、CNN、RNN、GRU、LSTM(输入是状态输出是动作,但这里会根据高斯分布做采样)根据具体情况来看动作空间是连续的还是离散的以及states的输入是什么
Multi-Gaussian mode(多元高斯模型)连续MLP、CNN、RNN、GRU、LSTM(输入是状态输出是动作,但这里会根据高斯分布做采样)根据具体情况来看动作空间是连续的还是离散的以及states的输入是什么
Deterministic Model(确定性模型)连续MLP、CNN、RNN、GRU、LSTM(输入是状态输出是动作,不做采样是确定的。)根据具体情况来看动作空间是连续的还是离散的以及states的输入是什么
Shared model(共享模型)在共享参数时使用比如critic-actor模型的策略网络和价值网络一样

2 基于Model instantiators的模型实例化

该API接口可以通过参数设置直接实例化模型,可以通过指定输入、隐藏层和激活函数来指定。

network=[{"name": <NAME>,  # container name"input": <INPUT>,  # container input (certain operations are supported)"layers": [  # list of supported layers<LAYER 1>,...,<LAYER N>,],"activations": [  # list of supported activation functions<ACTIVATION 1>,...,<ACTIVATION N>,],},
]

具体的实例化

models = {} #这是一个字典
models["q_network"] = deterministic_model(observation_space=env.observation_space,action_space=env.action_space,device=device,clip_actions=False,network=[{"name": "net","input": "STATES","layers": [64, 64],"activations": "relu",}],output="ACTIONS")
models["target_q_network"] = deterministic_model(observation_space=env.observation_space,action_space=env.action_space,device=device,clip_actions=False,network=[{"name": "net","input": "STATES","layers": [64, 64],"activations": "relu",}],output="ACTIONS")
http://www.xdnf.cn/news/541837.html

相关文章:

  • 银行反欺诈理论、方法与实践总结(下):解决方案
  • 【动手学深度学习】1.1~1.2 机器学习及其关键组件
  • 珈和科技贺李德仁院士荣膺国际数字地球学会会士:以时空智能赋能可持续发展目标 绘就数字地球未来蓝图
  • 基于pycharm,python,flask,tensorflow,keras,orm,mysql,在线深度学习sql语句检测系统
  • HarmonyOS5云服务技术分享--云缓存快速上手指南
  • 创建型:建造者模式
  • 跨域_Cross-origin resource sharing
  • SpringBoot-6-在IDEA中配置SpringBoot的Web开发测试环境
  • Spring Boot 多参数统一加解密方案详解:从原理到实战
  • 物流项目第三期(统一网关、工厂模式运用)
  • 普通人如何开发并训练自己的脑力?
  • npm vs npx 终极指南:从原理到实战的深度对比 全面解析包管理器与包执行器的核心差异,助你精准选择工具
  • 零基础深入解析 ngx_http_session_log_module
  • 视频太大?用魔影工厂压缩并转MP4,画质不打折!
  • 【缺陷】GaN和AlN中的掺杂特性
  • 小程序涉及提供提供文本深度合成技术,请补充选择:深度合成-AI问答类目
  • Golang的文件上传与下载
  • C++ 读取英伟达显卡名称、架构及算力
  • 服务器数据恢复—Linux系统服务器崩溃且重装系统的数据恢复案例
  • 常见高速电路设计与信号完整性核心概念
  • ubuntu下docker安装mongodb-支持单副本集
  • XTDrone配置ALOAM三维激光SLAM环境
  • GitLab部署
  • std::chrono类的简单使用实例及分析
  • 传输层协议:UDP和TCP
  • [原创](现代Delphi 12指南):[macOS 64bit App开发]: 如何获取目录大小?
  • 从Cookie到Token:Web开发认证机制演进史(保姆级拆解)
  • 深入解析MATLAB codegen生成MEX文件的原理与优势
  • PostgreSQL初体验
  • 深入解析 HTTP 中的 GET 请求与 POST 请求​