深度强化学习框架DI-engine
深度强化学习框架DI-engine
一、DI-engine概述:决策智能的通用引擎
DI-engine是由OpenDILab开源的决策智能引擎,基于PyTorch和JAX构建,旨在为强化学习(RL)、模仿学习(IL)、离线学习等场景提供标准化解决方案。其核心目标是统一不同决策智能环境与应用,支持从学术研究到工业原型的全流程开发。
核心特性
-
算法多样性
支持60+算法,覆盖传统DRL(如DQN、PPO、SAC)、多智能体RL(QMIX、COMA)、模仿学习(GAIL、SQIL)、离线学习(CQL、Decision Transformer)、基于模型的RL(MBPO、DreamerV3)等方向。 -
环境兼容性
内置80+环境,包括Atari、MuJoCo、SMAC(多智能体)、D4RL(离线学习)、Gfootball、Metadrive(自动驾驶)等,支持离散/连续/混合动作空间,兼容单智能体与多智能体场景。 -
系统级优化
- TreeTensor:树状嵌套数据结构,统一数据处理流程,支持自动微分和批量操作。
- 分布式训练:基于DDP和Kubernetes的分布式框架(DI-orchestrator),支持大规模并行训练。
- 中间件机制:模块化设计(如Replay Buffer、Reward Model),支持自定义流程扩展。
-
工程化工具链
提供配置系统(YAML/JSON)、模型加载/恢复、随机种子管理、单元测试框架(pytest)及Docker镜像(含预构建环境)。
二、功能模块与技术架构
1. 算法分类与支持列表
算法类型 | 典型算法举例 | 文档链接 | Demo命令示例 |
---|---|---|---|
基础DRL | DQN、PPO、SAC、TD3 | DQN文档 | ding -m serial -c cartpole_dqn_config.py -s 0 |
多智能体RL | QMIX、WQMIX、MADDPG | QMIX文档 | ding -m serial -c smac_3s5z_qmix_config.py |
模仿学习 | GAIL、SQIL、BCQ | GAIL文档 | ding -m serial_gail -c cartpole_gail_config.py |
离线学习 | CQL、TD3-BC、Decision Transformer | CQL文档 | python3 -u d4rl_cql_main.py |
基于模型的RL | MBPO、DreamerV3、STEVE | MBPO文档 | python3 -u pendulum_mbpo_config.py |
探索机制 | HER、RND、ICM | HER文档 | python3 -u bitflip_her_dqn.py |
2. 环境生态
DI-engine通过dizoo
模块集成丰富环境,部分示例如下:
- 经典控制:CartPole、Pendulum(单智能体离散/连续控制)。
- Atari:Pong、Asterix(视觉密集型任务)。
- 多智能体:SMAC(星际争霸微型战役)、PettingZoo(合作/竞争场景)。
- 离线学习:D4RL(MuJoCo专家轨迹数据集)。
- 真实应用:Metadrive(自动驾驶)、DI-star(星际争霸II决策AI)。
3. 核心组件
- Policy:封装算法逻辑,支持自定义网络结构(如CNN、Transformer)。
- Model:定义神经网络架构,支持PyTorch原生模块扩展。
- Collector/Learner:分离数据采集与模型训练,支持异步并行。
- Replay Buffer:支持优先经验回放(PER)、分层回放(PLR)等高级机制。
- TreeTensor:树状数据结构,示例代码:
import treetensor.torch as ttorch data = ttorch.randn({'obs': (3, 32, 32), 'action': (1,)}) # 嵌套张量 stacked_data = ttorch.stack([data, data], dim=0) # 批量操作
三、快速入门:安装与实战
1. 安装指南
方式1:Pip安装(推荐)
# 稳定版
pip install di-engine# 开发版(含额外依赖)
pip install "di-engine[all]"
方式2:Docker镜像
DI-engine提供预构建镜像,包含常见环境:
# CPU基础镜像
docker pull opendilab/ding:nightly# Atari环境镜像
docker pull opendilab/ding:nightly-atari# 运行示例(以CartPole为例)
docker run -it opendilab/ding:nightly \ding -m serial -c cartpole_dqn_config.py -s 0
2. 入门Demo:DQN玩转CartPole
步骤1:编写配置文件(cartpole_dqn_config.py
)
from ding.config import compile_config
from ding.policy import DQNPolicy
from ding.envs import GymEnvcfg = dict(env=dict(env_id='CartPole-v1',collector_env_num=8,evaluator_env_num=4,),policy=dict(cuda=False,model=dict(obs_shape=4,action_shape=2,),),
)
cfg = compile_config(cfg, seed=0, env=GymEnv, policy=DQNPolicy)
步骤2:运行训练
ding -m serial -c cartpole_dqn_config.py -s 0
步骤3:评估与可视化
from ding.utils import VideoRecorderenv = GymEnv('CartPole-v1', cfg.env)
policy = DQNPolicy(cfg.policy).load_checkpoint('ckpt.pth')
video_recorder = VideoRecorder(env, 'cartpole_demo.mp4')obs = env.reset()
while True:action = policy.predict(obs)obs, reward, done, info = env.step(action)video_recorder.record(obs)if done:break
video_recorder.close()
3. 多智能体示例:QMIX在SMAC环境
# SMAC 3s5z场景(3陆战队vs5 zealots)
ding -m serial -c smac_3s5z_qmix_config.py -s 0
四、进阶功能与最佳实践
1. 自定义环境迁移
若需接入自定义环境,需实现以下接口:
from ding.envs import BaseEnvclass MyEnv(BaseEnv):def __init__(self, env_id):super().__init__()# 初始化环境逻辑def reset(self):# 返回初始观测值return obsdef step(self, action):# 执行动作,返回obs, reward, done, inforeturn obs, reward, done, info@propertydef observation_space(self):# 定义观测空间(gym.Space格式)return Box(low=0, high=255, shape=(3, 32, 32))@propertydef action_space(self):# 定义动作空间return Discrete(10)
2. 模型自定义与网络设计
通过继承ding.model.BaseModel
实现自定义网络:
import torch.nn as nn
from ding.model import BaseModelclass CustomModel(BaseModel):def __init__(self, obs_shape, action_shape):super().__init__()self.cnn = nn.Sequential(nn.Conv2d(obs_shape[0], 32, 3),nn.ReLU(),nn.Flatten())self.fc = nn.Linear(32*14*14, action_shape)def forward(self, x):x = self.cnn(x)return self.fc(x)
3. 分布式训练
使用DDP模式启动分布式训练:
ding -m distributed_ddp -c ppo_lunarlander_config.py -n 4
五、生态与社区支持
1. 工具与资源
- DI-zoo:包含30+算法示例与基准环境,地址:https://github.com/opendilab/DI-zoo。
- 教程与文档:
- 快速入门:3分钟上手指南
- 中文文档:ReadTheDocs
- 算法速查表:RL Algorithms Cheat Sheet
- 开源工具链:
- TreeTensor:树状张量库,GitHub
- DI-toolkit:决策智能工具包,PyPI
2. 社区参与
- 问题反馈:在GitHub Issues提交BUG或功能请求。
- 贡献代码:参考CONTRIBUTING.md,参与算法实现或文档完善。
- 交流渠道:
- 微信社群:扫码添加“DI小助手”(见GitHub README)。
- Discord/ Slack:链接。
3. 引用与许可
若在研究中使用DI-engine,请引用:
@misc{ding,title={DI-engine: A Universal AI System/Engine for Decision Intelligence},author={Niu, Yazhe et al.},howpublished={\url{https://github.com/opendilab/DI-engine}},year={2021}
}
DI-engine采用Apache 2.0许可证,允许商业使用与修改。
DI-engine核心算法实战指南
一、DQN(离散动作空间经典算法)
算法特性
- 适用场景:离散动作空间、单智能体、稀疏奖励场景
- 核心思想:基于Q-Learning,使用神经网络近似Q值函数,经验回放+目标网络稳定训练
- DI-engine实现:支持Double DQN、Dueling DQN等变体,集成PER(优先经验回放)
支持环境
环境类别 | 具体环境示例 | 动作空间类型 | 文档链接 |
---|---|---|---|
经典控制 | CartPole-v1、CliffWalking | 离散 | CartPole文档 |
Atari游戏 | Pong、Breakout、Asterix | 离散 | Atari环境指南 |
文本决策 | TabMWP(数学文字题解答) | 离散 | TabMWP文档 |
快速上手:CartPole场景
1. 配置文件(cartpole_dqn_config.py
)
from ding.config import compile_config
from ding.envs import GymEnv
from ding.policy import DQNPolicycfg = dict(env=dict(env_id='CartPole-v1',collector_env_num=8, # 并行采集环境数evaluator_env_num=4, # 并行评估环境数n_evaluator_episode=20, # 评估 episodes 数),policy=dict(cuda=False,model=dict(obs_shape=4, # 观测空间维度action_shape=2, # 动作空间大小(离散)encoder_hidden_size_list=[128, 128], # 网络结构),learn=dict(update_per_collect=10, # 每次采集后更新次数batch_size=32, # 批量大小learning_rate=0.001, # 学习率),collect=dict(n_sample=100, # 每次采集样本数random_collect_size=1000, # 随机初始化样本数),eval=dict(evaluator=dict(eval_freq=500, )) # 评估频率),
)
cfg = compile_config(cfg, env=GymEnv, policy=DQNPolicy, seed=0)
2. 运行训练
# 单机训练
ding -m serial -c cartpole_dqn_config.py -s 0# 分布式训练(4进程)
ding -m distributed_ddp -c cartpole_dqn_config.py -n 4
3. 评估与可视化
from ding.utils import VideoRecorder
env = GymEnv('CartPole-v1', cfg.env)
policy = DQNPolicy(cfg.policy).load_checkpoint('output/ckpt.pth')
video_recorder = VideoRecorder(env, 'cartpole_demo.mp4')obs = env.reset()
while True:action = policy.predict(obs) # 推理动作obs, reward, done, info = env.step(action)video_recorder.record(obs)if done:break
video_recorder.close()
二、PPO(连续/离散通用策略梯度算法)
算法特性
- 适用场景:离散/连续动作空间、单智能体/多智能体(MAPPO变体)
- 核心思想:近端策略优化,通过clip机制平衡策略更新步长
- DI-engine扩展:支持GAE(广义优势估计)、分层回放(PLR)
支持环境
环境类别 | 具体环境示例 | 动作空间类型 | 文档链接 |
---|---|---|---|
经典控制 | LunarLanderContinuous-v2 | 连续 | LunarLander文档 |
MuJoCo | HalfCheetah-v4、Ant-v4 | 连续 | MuJoCo环境指南 |
多智能体 | SMAC(3s5z场景)、PettingZoo | 离散/MARL | SMAC文档 |
实战案例:LunarLander连续控制
1. 配置要点(连续动作特化)
cfg.policy.model = dict(obs_shape=8,action_shape=4,action_space='continuous', # 显式声明连续空间actor_head_hidden_size=256,critic_head_hidden_size=256,
)
cfg.policy.learn = dict(epoch_per_collect=10, # 每个采集周期训练轮数batch_size=2048,clip_ratio=0.2, # PPO核心超参数
)
2. 运行命令
# 连续动作版示例(需指定环境ID)
ding -m serial_onpolicy -c lunarlander_ppo_continuous_config.py -s 0
3. 多智能体扩展(MAPPO)
# SMAC 3s5z多智能体场景
ding -m serial -c smac_3s5z_mappo_config.py -s 0
三、SAC(连续动作空间最大熵强化学习)
算法特性
- 适用场景:连续动作空间、探索性强的环境(如机器人控制)
- 核心思想:结合最大熵原理,同时优化策略熵与累积奖励
- DI-engine优化:支持自动熵调整、多Q网络正则化
支持环境
环境类别 | 具体环境示例 | 动作空间类型 | 文档链接 |
---|---|---|---|
经典控制 | Pendulum-v1 | 连续 | Pendulum文档 |
MuJoCo | Swimmer-v4、Walker2d-v4 | 连续 | MuJoCo环境列表 |
离线学习 | D4RL(HalfCheetah-medium) | 连续(离线) | D4RL文档 |
代码示例:Pendulum摆锤控制
1. 关键配置
cfg.policy.model = dict(obs_shape=3,action_shape=1,twin_critic=True, # 使用双Q网络action_space='continuous',
)
cfg.policy.learn = dict(target_entropy='auto', # 自动计算目标熵discount_factor=0.99,
)
2. 运行与评估
# 训练命令
ding -m serial -c pendulum_sac_config.py -s 0# 离线学习(D4RL数据集)
python3 -u d4rl_sac_main.py --env_id halfcheetah-medium-v0
四、TD3(连续动作空间延迟策略更新算法)
算法特性
- 适用场景:连续动作空间、高维动作空间(如机械臂控制)
- 核心改进:双Q网络+策略延迟更新+动作噪声,缓解过估计问题
- DI-engine集成:支持与HER(indsight experience replay)结合
支持环境
环境类别 | 具体环境示例 | 动作空间类型 | 文档链接 |
---|---|---|---|
经典控制 | MountainCarContinuous-v0 | 连续 | MountainCar文档 |
PyBullet | HumanoidBulletEnv-v0 | 连续 | PyBullet环境指南 |
自定义环境 | 机械臂控制(需迁移) | 连续 | 环境迁移教程 |
快速运行:MountainCarContinuous
# 配置文件路径:ding/example/td3_mountaincar_continuous_config.py
ding -m serial -c td3_mountaincar_continuous_config.py -s 0
高级用法:结合HER探索
# 在配置中启用HER
cfg.policy.other = dict(replay_buffer=dict(type='her',her_type='future', # HER类型(future/episode等))
)
五、环境支持总表(DQN/PPO/SAC/TD3适用场景)
算法 | 离散动作空间 | 连续动作空间 | 多智能体场景 | 离线学习场景 | 典型环境示例 |
---|---|---|---|---|---|
DQN | ✅(全支持) | ❌ | ❌ | ❌ | CartPole、Breakout |
PPO | ✅(离散版) | ✅(连续版) | ✅(MAPPO) | ❌ | LunarLander、SMAC |
SAC | ❌ | ✅(全支持) | ❌ | ✅(D4RL) | Pendulum、D4RL-HalfCheetah |
TD3 | ❌ | ✅(高维优先) | ❌ | ❌ | MountainCarContinuous、Ant |
六、进阶技巧:算法调优与环境适配
1. 离散vs连续动作空间关键差异
- 动作处理:
- 离散:输出logits或Q值,通过
argmax
选择动作 - 连续:输出均值+标准差(如SAC)或直接映射(如TD3),需限制动作范围
- 离散:输出logits或Q值,通过
- 网络结构:
- 离散:单分支输出(动作维度)
- 连续:Actor-Critic双分支(如PPO/SAC)
2. 环境迁移三步法
- 继承BaseEnv:实现
reset()
、step(action)
、observation_space
等接口 - 适配动作格式:
- 离散:动作需为0~n-1的整数(
np.int64
) - 连续:动作需为numpy数组,范围与
action_space
一致
- 离散:动作需为0~n-1的整数(
- 测试兼容性:使用
ding.utils.env_test
验证环境接口
from ding.utils import env_test
env = MyCustomEnv()
env_test(env, coll_num=100) # 测试采集兼容性
3. 超参数调优建议
算法 | 关键超参数 | 调优方向 |
---|---|---|
DQN | learning_rate 、gamma | 小学习率(1e-4~1e-3),gamma=0.99 |
PPO | clip_ratio 、gae_lambda | clip_ratio=0.1~0.3,lambda=0.95 |
SAC | target_entropy 、tau | 自动熵或手动设置(如-env_dim) |
TD3 | policy_noise 、noise_clip | 噪声=0.2,clip=0.5 |
七、总结:选择合适算法的决策树
通过以上指南,可快速在DI-engine中落地主流RL算法。更多细节可参考官方文档:
- 算法文档合集
- 环境迁移教程
- 完整示例代码