第四章、SKRL(1): Examples
0 前言
官方文档:https://skrl.readthedocs.io/en/latest/intro/examples.html
本节将从官方的案例入手来学习及梳理skrl
库。
1 Gymnasium / Gym 环境
以torch_gym_cartpole_dqn.py
文件为例进行梳理:
整个流程抽象出来其实可以分成以下几步
- 加载并包装Gym环境:获取了一个强化学习的env环境
- 将gym环境包装成
skrl
可用的兼容格式 - 实例化经验放回缓冲区
- 使用模型实例化工具创建DQN所需模型(Q网络和目标Q网络)
- 初始化模型参数(权重和偏置使用正态分布初始化)
- 实例化DQN代理
- 配置并实例化序列训练器
- 启动训练过程
如果你熟悉官方的API接口,会发现这个流程和提供的接口的顺序基本上是一致的
import gym # 导入Gym库,用于创建和管理强化学习环境# 导入skrl库中的组件,用于构建强化学习系统
from skrl.agents.torch.dqn import DQN, DQN_DEFAULT_CONFIG
from skrl.envs.wrappers.torch import wrap_env
from skrl.memories.torch import RandomMemory
from skrl.trainers.torch import SequentialTrainer
from skrl.utils import set_seed
from skrl.utils.model_instantiators.torch import Shape, deterministic_model# 设置随机种子以保证实验可重复性
set_seed() # 例如 set_seed(42) 可以设置固定种子# 加载并包装Gym环境(处理不同Gym版本的兼容性问题)
try:env = gym.make("CartPole-v0") # 尝试加载经典控制问题CartPole环境
except gym.error.DeprecatedEnv as e:# 如果v0版本不可用,自动查找最新可用版本env_id = [spec.id for spec in gym.envs.registry.all() if spec.id.startswith("CartPole-v")][0]print("CartPole-v0 not found. Trying {}".format(env_id))env = gym.make(env_id)
env = wrap_env(env) # 将Gym环境包装为skrl兼容格式device = env.device # 获取计算设备(CPU/GPU)# 实例化经验回放缓冲区(Experience Replay)
memory = RandomMemory(memory_size=50000, # 记忆容量num_envs=env.num_envs, # 并行环境数量device=device, # 存储设备replacement=False # 无放回采样
)# 使用模型实例化工具创建DQN所需模型(Q网络和目标Q网络)
models = {}
# 主Q网络(策略网络)
models["q_network"] = deterministic_model(observation_space=env.observation_space,action_space=env.action_space,device=device,clip_actions=False,network=[{"name": "net","input": "STATES","layers": [64, 64], # 两个隐藏层,每层64个神经元"activations": "relu", # 使用ReLU激活函数}],output="ACTIONS"
)
# 目标Q网络(定期同步的稳定目标)
models["target_q_network"] = deterministic_model(observation_space=env.observation_space,action_space=env.action_space,device=device,clip_actions=False,network=[{"name": "net","input": "STATES","layers": [64, 64],"activations": "relu",}],output="ACTIONS"
)# 初始化模型的延迟模块(Lazy Modules)
for role, model in models.items():model.init_state_dict(role)# 初始化模型参数(权重和偏置使用正态分布初始化)
for model in models.values():model.init_parameters(method_name="normal_", mean=0.0, std=0.1)# 配置DQN代理参数
cfg = DQN_DEFAULT_CONFIG.copy()
cfg["learning_starts"] = 100 # 在开始学习前先收集100步的经验
cfg["exploration"]["final_epsilon"] = 0.04 # 探索率最终值(4%的随机探索)
cfg["exploration"]["timesteps"] = 1500 # 探索率衰减步数
cfg["experiment"]["write_interval"] = 1000 # 每1000步写入TensorBoard日志
cfg["experiment"]["checkpoint_interval"] = 5000 # 每5000步保存模型检查点
cfg["experiment"]["directory"] = "runs/torch/CartPole" # 实验数据保存路径# 实例化DQN代理
agent = DQN(models=models,memory=memory,cfg=cfg,observation_space=env.observation_space,action_space=env.action_space,device=device
)# 配置并实例化序列训练器
cfg_trainer = {"timesteps": 50000, # 总训练时间步"headless": True # 无可视化模式
}
trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent])# 启动训练过程
trainer.train()
运行代码会在你目前所在的文件夹下生成一个runs
文件夹,该文件夹下保存了相关的模型文件runs/torch/CartPole/你的运行时间_DQN
,打开checkpoints
文件夹,会看到其下众多的pt模型文件。
1、RandomMemory部分详情参考:https://blog.csdn.net/m0_47719040/article/details/147978672?sharetype=blogdetail&sharerId=147978672&sharerefer=PC&sharesource=m0_47719040&spm=1011.2480.3001.8118
2、模型实例化部分,创建了一个目标Q网络和主Q网络,并根据model
选择Agent
。
# 使用模型实例化工具创建DQN所需模型(Q网络和目标Q网络)
models = {}
# 主Q网络(策略网络)
models["q_network"] = deterministic_model(observation_space=env.observation_space,action_space=env.action_space,device=device,clip_actions=False,network=[{"name": "net","input": "STATES","layers": [64, 64], # 两个隐藏层,每层64个神经元"activations": "relu", # 使用ReLU激活函数}],output="ACTIONS"
)
# 目标Q网络(定期同步的稳定目标)
models["target_q_network"] = deterministic_model(observation_space=env.observation_space,action_space=env.action_space,device=device,clip_actions=False,network=[{"name": "net","input": "STATES","layers": [64, 64],"activations": "relu",}],output="ACTIONS"
)
# 实例化DQN代理
agent = DQN(models=models,memory=memory,cfg=cfg,observation_space=env.observation_space,action_space=env.action_space,device=device
)
3、创建训练器并进行训练