当前位置: 首页 > news >正文

stable-baseline3介绍

🤖 Stable-Baselines3 高级使用指南:模型创建、结构修改与超参数优化

本教程假设你已经安装好 Stable-Baselines3,并熟悉基本的 RL 概念(Agent、Environment、Reward)。
内容涵盖:

  • 模型创建与训练
  • 修改神经网络结构
  • 调整超参数
  • 保存与加载模型
  • 对比训练效果

一、创建 RL 模型

在 Stable-Baselines3 中,创建模型非常简单。以 PPO 为例:

from stable_baselines3 import PPO# 假设 env 已经创建好了
model = PPO("MlpPolicy", env, verbose=1)
  • "MlpPolicy" 表示使用全连接多层感知机(MLP)策略
  • env 是 Gym 或 Gymnasium 环境
  • verbose=1 会输出训练日志

Tip:其他算法类似,如 DQN、A2C、SAC、TD3。


二、修改神经网络结构

Stable-Baselines3 提供 policy_kwargs 参数来定义 Actor 和 Critic 网络结构。

1. 修改网络层数和单元数

policy_kwargs = dict(net_arch=[dict(pi=[128, 128], vf=[128, 128])]  # Actor 和 Critic 两层128单元
)model = PPO("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
  • pi:Actor 网络(输出动作概率/策略)
  • vf:Critic 网络(输出状态值)
  • 可以自由修改层数或单元数,例如 [64,64,64]

2. 使用共享网络

policy_kwargs = dict(net_arch=[256, 256]  # Actor和Critic共享网络
)model = PPO("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)

3. 自定义网络模块

from torch import nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractorclass CustomMLP(BaseFeaturesExtractor):def __init__(self, observation_space, features_dim=128):super().__init__(observation_space, features_dim)self.net = nn.Sequential(nn.Linear(observation_space.shape[0], 128),nn.ReLU(),nn.Linear(128, features_dim),nn.ReLU())def forward(self, x):return self.net(x)policy_kwargs = dict(features_extractor_class=CustomMLP,features_extractor_kwargs=dict(features_dim=128)
)model = PPO("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)

Tip:适合需要完全自定义特征提取或复杂输入的环境。


三、调整超参数

Stable-Baselines3 提供丰富的超参数可调:

参数作用示例
learning_rate学习率0.0003, 0.001
n_steps每轮采样步数128, 512
batch_size更新批次大小64, 256
gamma折扣因子0.95, 0.99
ent_coef熵系数(鼓励探索)0.0, 0.01
clip_rangePPO 裁剪范围0.2

示例:调整 PPO 超参数

model = PPO("MlpPolicy",env,learning_rate=0.001,n_steps=512,batch_size=128,gamma=0.99,ent_coef=0.01,verbose=1
)

Tip:先固定大部分参数,逐步调整一个或两个关键参数(如学习率和网络大小),更容易观察效果变化。


四、训练与保存模型

# 训练
model.learn(total_timesteps=50000)# 保存
model.save("ppo_custom_model")# 加载
model = PPO.load("ppo_custom_model", env=env)

Tip:保存时只需指定模型名称,加载时指定环境即可继续训练。


五、策略测试与对比训练效果

1. 测试策略

obs = env.reset()
for _ in range(1000):action, _ = model.predict(obs)obs, reward, done, info = env.step(action)env.render()if done:obs = env.reset()

2. 记录奖励并绘制曲线

from stable_baselines3.common.monitor import Monitor
import matplotlib.pyplot as plt
import pandas as pdenv = Monitor(env, "logs/")  # 保存训练日志# 训练
model.learn(total_timesteps=50000)# 读取日志
data = pd.read_csv("logs/monitor.csv", skiprows=1)
plt.plot(data['l'].rolling(50).mean())  # 滑动平均奖励
plt.xlabel("Episode")
plt.ylabel("Reward")
plt.title("Training Reward Curve")
plt.show()

3. 对比多组实验

  • 使用不同网络规模、不同学习率、不同 n_steps 训练多个模型
  • 绘制多条曲线在同一张图中,直观比较训练速度、稳定性和最终平均奖励
plt.plot(reward_curve_model1, label="Small Net")
plt.plot(reward_curve_model2, label="Large Net")
plt.plot(reward_curve_model3, label="High LR")
plt.legend()
plt.show()

六、实战建议

  1. 网络结构

    • 小环境(CartPole、MountainCar):64~128 单元即可
    • 复杂环境(Atari、MuJoCo):128~256 单元,多层网络
  2. 超参数调整

    • 学习率(learning_rate)先调
    • 然后调整采样步数 n_steps
    • 折扣因子 gamma、熵系数 ent_coef 可微调
  3. 实验对比

    • 每次只修改 1~2 个参数
    • 固定随机种子确保可比性
    • 保存每次模型和日志,绘制奖励曲线

✅ 总结:

  • Stable-Baselines3 提供灵活的 模型创建、网络结构修改和超参数调整
  • 熟练掌握 policy_kwargslearning_raten_steps 等参数,可快速优化训练效果
  • 通过日志和奖励曲线对比,不断迭代实验,找到最佳配置
http://www.xdnf.cn/news/1382707.html

相关文章:

  • 个人博客运行3个月记录
  • mac m4执行nvm install 14.19.1报错,安装低版本node报错解决
  • 【STM32】G030单片机的窗口看门狗
  • Flutter:ios打包ipa,证书申请,Xcode打包,完整流程
  • LeetCode Hot 100 第7天
  • mac系统本地部署Dify步骤梳理
  • 仓颉编程语言青少年基础教程:输入输出
  • 模拟实现Linux中的进度条
  • [Mysql数据库] 知识点总结5
  • 天津医科大学肿瘤医院冷热源群控系统调试完成:以 “精准控温 + 高效节能” 守护医疗核心场景
  • 实战演练(一):从零构建一个功能完备的Todo List应用
  • Spring事务管理机制深度解析:从JDBC基础到Spring高级实现
  • 力扣(LeetCode) ——965. 单值二叉树(C语言)
  • C#写的一键自动测灯带的应用 AI帮写的。
  • [灵动微电子 MM32BIN560CN MM32SPIN0280]读懂电机MCU之串口DMA
  • list 手动实现 1
  • 学习日志40 python
  • 微服务即时通信系统(十三)--- 项目部署
  • 【后端】微服务后端鉴权方案
  • 虚函数指针和虚函数表的创建时机和存放位置
  • 【Linux知识】Linux 设置账号密码永不过期
  • 完整代码注释:实现 Qt 的 TCP 客户端,实现和服务器通信
  • 【LINUX网络】TCP原理
  • WEEX唯客上线C2C交易平台:打造安全便捷的用户交易体验
  • 现在购买PCIe 5.0 SSD是否是最好的时机?
  • 前端实现Linux查询平台:打造高效运维工作流
  • [光学原理与应用-320]:光学产品不同阶段使用的工具软件、对应的输出文件
  • 华为S5720S重置密码
  • c语言动态数组扩容
  • MCU平台化实践方案