当前位置: 首页 > java >正文

第四章、SKRL(1): Examples

0 前言

官方文档:https://skrl.readthedocs.io/en/latest/intro/examples.html

本节将从官方的案例入手来学习及梳理skrl库。

1 Gymnasium / Gym 环境

torch_gym_cartpole_dqn.py文件为例进行梳理:

整个流程抽象出来其实可以分成以下几步

  • 加载并包装Gym环境:获取了一个强化学习的env环境
  • 将gym环境包装成skrl可用的兼容格式
  • 实例化经验放回缓冲区
  • 使用模型实例化工具创建DQN所需模型(Q网络和目标Q网络)
  • 初始化模型参数(权重和偏置使用正态分布初始化)
  • 实例化DQN代理
  • 配置并实例化序列训练器
  • 启动训练过程

如果你熟悉官方的API接口,会发现这个流程和提供的接口的顺序基本上是一致的
在这里插入图片描述

import gym  # 导入Gym库,用于创建和管理强化学习环境# 导入skrl库中的组件,用于构建强化学习系统
from skrl.agents.torch.dqn import DQN, DQN_DEFAULT_CONFIG
from skrl.envs.wrappers.torch import wrap_env
from skrl.memories.torch import RandomMemory
from skrl.trainers.torch import SequentialTrainer
from skrl.utils import set_seed
from skrl.utils.model_instantiators.torch import Shape, deterministic_model# 设置随机种子以保证实验可重复性
set_seed()  # 例如 set_seed(42) 可以设置固定种子# 加载并包装Gym环境(处理不同Gym版本的兼容性问题)
try:env = gym.make("CartPole-v0")  # 尝试加载经典控制问题CartPole环境
except gym.error.DeprecatedEnv as e:# 如果v0版本不可用,自动查找最新可用版本env_id = [spec.id for spec in gym.envs.registry.all() if spec.id.startswith("CartPole-v")][0]print("CartPole-v0 not found. Trying {}".format(env_id))env = gym.make(env_id)
env = wrap_env(env)  # 将Gym环境包装为skrl兼容格式device = env.device  # 获取计算设备(CPU/GPU)# 实例化经验回放缓冲区(Experience Replay)
memory = RandomMemory(memory_size=50000,    # 记忆容量num_envs=env.num_envs,  # 并行环境数量device=device,       # 存储设备replacement=False    # 无放回采样
)# 使用模型实例化工具创建DQN所需模型(Q网络和目标Q网络)
models = {}
# 主Q网络(策略网络)
models["q_network"] = deterministic_model(observation_space=env.observation_space,action_space=env.action_space,device=device,clip_actions=False,network=[{"name": "net","input": "STATES","layers": [64, 64],      # 两个隐藏层,每层64个神经元"activations": "relu",   # 使用ReLU激活函数}],output="ACTIONS"
)
# 目标Q网络(定期同步的稳定目标)
models["target_q_network"] = deterministic_model(observation_space=env.observation_space,action_space=env.action_space,device=device,clip_actions=False,network=[{"name": "net","input": "STATES","layers": [64, 64],"activations": "relu",}],output="ACTIONS"
)# 初始化模型的延迟模块(Lazy Modules)
for role, model in models.items():model.init_state_dict(role)# 初始化模型参数(权重和偏置使用正态分布初始化)
for model in models.values():model.init_parameters(method_name="normal_", mean=0.0, std=0.1)# 配置DQN代理参数
cfg = DQN_DEFAULT_CONFIG.copy()
cfg["learning_starts"] = 100          # 在开始学习前先收集100步的经验
cfg["exploration"]["final_epsilon"] = 0.04  # 探索率最终值(4%的随机探索)
cfg["exploration"]["timesteps"] = 1500      # 探索率衰减步数
cfg["experiment"]["write_interval"] = 1000  # 每1000步写入TensorBoard日志
cfg["experiment"]["checkpoint_interval"] = 5000  # 每5000步保存模型检查点
cfg["experiment"]["directory"] = "runs/torch/CartPole"  # 实验数据保存路径# 实例化DQN代理
agent = DQN(models=models,memory=memory,cfg=cfg,observation_space=env.observation_space,action_space=env.action_space,device=device
)# 配置并实例化序列训练器
cfg_trainer = {"timesteps": 50000,  # 总训练时间步"headless": True     # 无可视化模式
}
trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent])# 启动训练过程
trainer.train()

运行代码会在你目前所在的文件夹下生成一个runs文件夹,该文件夹下保存了相关的模型文件runs/torch/CartPole/你的运行时间_DQN,打开checkpoints文件夹,会看到其下众多的pt模型文件。

1、RandomMemory部分详情参考:https://blog.csdn.net/m0_47719040/article/details/147978672?sharetype=blogdetail&sharerId=147978672&sharerefer=PC&sharesource=m0_47719040&spm=1011.2480.3001.8118

2、模型实例化部分,创建了一个目标Q网络和主Q网络,并根据model选择Agent

# 使用模型实例化工具创建DQN所需模型(Q网络和目标Q网络)
models = {}
# 主Q网络(策略网络)
models["q_network"] = deterministic_model(observation_space=env.observation_space,action_space=env.action_space,device=device,clip_actions=False,network=[{"name": "net","input": "STATES","layers": [64, 64],      # 两个隐藏层,每层64个神经元"activations": "relu",   # 使用ReLU激活函数}],output="ACTIONS"
)
# 目标Q网络(定期同步的稳定目标)
models["target_q_network"] = deterministic_model(observation_space=env.observation_space,action_space=env.action_space,device=device,clip_actions=False,network=[{"name": "net","input": "STATES","layers": [64, 64],"activations": "relu",}],output="ACTIONS"
)
# 实例化DQN代理
agent = DQN(models=models,memory=memory,cfg=cfg,observation_space=env.observation_space,action_space=env.action_space,device=device
)

3、创建训练器并进行训练

http://www.xdnf.cn/news/7315.html

相关文章:

  • Python实例题:Python 实现简易 Shell
  • Python的传参过程的小细节
  • 什么是5G前传、中传、回传?
  • 数据分析—Excel数据清洗函数
  • Compose Kotlin Multiplatform跨平台基础运行
  • CM0启动CM7_0、CM7_1注意事项
  • PCB设计教程【入门篇】——电路分析基础-基本元件(电阻电容电感)
  • Docker 入门指南:从安装配置到核心概念解析
  • [ 计算机网络 ] | 宏观谈谈计算机网络
  • 十三、Hive 行列转换
  • 计算机视觉与深度学习 | Python实现ARIMA-WOA-CNN-LSTM时间序列预测(完整源码和数据
  • netcore项目使用winforms与blazor结合来开发如何按F12,可以调出chrome devtool工具辅助开发
  • 通过低功耗蓝牙通信实例讲透 MCU 各个定时器
  • AT 指令详解:基于 MCU 的通信控制实战指南AT 指令详解
  • ESP32开发-两个WIFI设备的通讯搭建
  • AI大模型从0到1记录学习numpy pandas day25
  • 无人设备遥控器之数据压缩与编码技术篇
  • PLC组网的方法、要点及实施全解析
  • android13以太网静态ip不断断开连上问题
  • C++(24):容器类<list>
  • Unreal 从入门到精通之SceneCaptureComponent2D实现UI层3D物体360°预览
  • MAC常用操作整理
  • MAC电脑中右键后复制和拷贝的区别
  • C++:与7无关的数
  • 基于 Vue 和 Node.js 实现图片上传功能:从前端到后端的完整实践
  • 汽车零部件的EMI抗扰性测试
  • Java中的流详解
  • vue3 vite 路由
  • 容器化-K8s-镜像仓库使用和应用
  • Ubuntu Desktop QEMU/KVM中使用Ubuntu Server 22.04配置k8s集群