stable-baseline3介绍
🤖 Stable-Baselines3 高级使用指南:模型创建、结构修改与超参数优化
本教程假设你已经安装好 Stable-Baselines3,并熟悉基本的 RL 概念(Agent、Environment、Reward)。
内容涵盖:
- 模型创建与训练
- 修改神经网络结构
- 调整超参数
- 保存与加载模型
- 对比训练效果
一、创建 RL 模型
在 Stable-Baselines3 中,创建模型非常简单。以 PPO 为例:
from stable_baselines3 import PPO# 假设 env 已经创建好了
model = PPO("MlpPolicy", env, verbose=1)
"MlpPolicy"
表示使用全连接多层感知机(MLP)策略env
是 Gym 或 Gymnasium 环境verbose=1
会输出训练日志
Tip:其他算法类似,如 DQN、A2C、SAC、TD3。
二、修改神经网络结构
Stable-Baselines3 提供 policy_kwargs
参数来定义 Actor 和 Critic 网络结构。
1. 修改网络层数和单元数
policy_kwargs = dict(net_arch=[dict(pi=[128, 128], vf=[128, 128])] # Actor 和 Critic 两层128单元
)model = PPO("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
pi
:Actor 网络(输出动作概率/策略)vf
:Critic 网络(输出状态值)- 可以自由修改层数或单元数,例如
[64,64,64]
2. 使用共享网络
policy_kwargs = dict(net_arch=[256, 256] # Actor和Critic共享网络
)model = PPO("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
3. 自定义网络模块
from torch import nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractorclass CustomMLP(BaseFeaturesExtractor):def __init__(self, observation_space, features_dim=128):super().__init__(observation_space, features_dim)self.net = nn.Sequential(nn.Linear(observation_space.shape[0], 128),nn.ReLU(),nn.Linear(128, features_dim),nn.ReLU())def forward(self, x):return self.net(x)policy_kwargs = dict(features_extractor_class=CustomMLP,features_extractor_kwargs=dict(features_dim=128)
)model = PPO("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
Tip:适合需要完全自定义特征提取或复杂输入的环境。
三、调整超参数
Stable-Baselines3 提供丰富的超参数可调:
参数 | 作用 | 示例 |
---|---|---|
learning_rate | 学习率 | 0.0003, 0.001 |
n_steps | 每轮采样步数 | 128, 512 |
batch_size | 更新批次大小 | 64, 256 |
gamma | 折扣因子 | 0.95, 0.99 |
ent_coef | 熵系数(鼓励探索) | 0.0, 0.01 |
clip_range | PPO 裁剪范围 | 0.2 |
示例:调整 PPO 超参数
model = PPO("MlpPolicy",env,learning_rate=0.001,n_steps=512,batch_size=128,gamma=0.99,ent_coef=0.01,verbose=1
)
Tip:先固定大部分参数,逐步调整一个或两个关键参数(如学习率和网络大小),更容易观察效果变化。
四、训练与保存模型
# 训练
model.learn(total_timesteps=50000)# 保存
model.save("ppo_custom_model")# 加载
model = PPO.load("ppo_custom_model", env=env)
Tip:保存时只需指定模型名称,加载时指定环境即可继续训练。
五、策略测试与对比训练效果
1. 测试策略
obs = env.reset()
for _ in range(1000):action, _ = model.predict(obs)obs, reward, done, info = env.step(action)env.render()if done:obs = env.reset()
2. 记录奖励并绘制曲线
from stable_baselines3.common.monitor import Monitor
import matplotlib.pyplot as plt
import pandas as pdenv = Monitor(env, "logs/") # 保存训练日志# 训练
model.learn(total_timesteps=50000)# 读取日志
data = pd.read_csv("logs/monitor.csv", skiprows=1)
plt.plot(data['l'].rolling(50).mean()) # 滑动平均奖励
plt.xlabel("Episode")
plt.ylabel("Reward")
plt.title("Training Reward Curve")
plt.show()
3. 对比多组实验
- 使用不同网络规模、不同学习率、不同 n_steps 训练多个模型
- 绘制多条曲线在同一张图中,直观比较训练速度、稳定性和最终平均奖励
plt.plot(reward_curve_model1, label="Small Net")
plt.plot(reward_curve_model2, label="Large Net")
plt.plot(reward_curve_model3, label="High LR")
plt.legend()
plt.show()
六、实战建议
-
网络结构
- 小环境(CartPole、MountainCar):64~128 单元即可
- 复杂环境(Atari、MuJoCo):128~256 单元,多层网络
-
超参数调整
- 学习率(learning_rate)先调
- 然后调整采样步数 n_steps
- 折扣因子 gamma、熵系数 ent_coef 可微调
-
实验对比
- 每次只修改 1~2 个参数
- 固定随机种子确保可比性
- 保存每次模型和日志,绘制奖励曲线
✅ 总结:
- Stable-Baselines3 提供灵活的 模型创建、网络结构修改和超参数调整
- 熟练掌握
policy_kwargs
、learning_rate
、n_steps
等参数,可快速优化训练效果 - 通过日志和奖励曲线对比,不断迭代实验,找到最佳配置