当前位置: 首页 > news >正文

深度强化学习框架DI-engine

深度强化学习框架DI-engine

一、DI-engine概述:决策智能的通用引擎

DI-engine是由OpenDILab开源的决策智能引擎,基于PyTorch和JAX构建,旨在为强化学习(RL)、模仿学习(IL)、离线学习等场景提供标准化解决方案。其核心目标是统一不同决策智能环境与应用,支持从学术研究到工业原型的全流程开发。

核心特性

  1. 算法多样性
    支持60+算法,覆盖传统DRL(如DQN、PPO、SAC)、多智能体RL(QMIX、COMA)、模仿学习(GAIL、SQIL)、离线学习(CQL、Decision Transformer)、基于模型的RL(MBPO、DreamerV3)等方向。

  2. 环境兼容性
    内置80+环境,包括Atari、MuJoCo、SMAC(多智能体)、D4RL(离线学习)、Gfootball、Metadrive(自动驾驶)等,支持离散/连续/混合动作空间,兼容单智能体与多智能体场景。

  3. 系统级优化

    • TreeTensor:树状嵌套数据结构,统一数据处理流程,支持自动微分和批量操作。
    • 分布式训练:基于DDP和Kubernetes的分布式框架(DI-orchestrator),支持大规模并行训练。
    • 中间件机制:模块化设计(如Replay Buffer、Reward Model),支持自定义流程扩展。
  4. 工程化工具链
    提供配置系统(YAML/JSON)、模型加载/恢复、随机种子管理、单元测试框架(pytest)及Docker镜像(含预构建环境)。

二、功能模块与技术架构

1. 算法分类与支持列表

算法类型典型算法举例文档链接Demo命令示例
基础DRLDQN、PPO、SAC、TD3DQN文档ding -m serial -c cartpole_dqn_config.py -s 0
多智能体RLQMIX、WQMIX、MADDPGQMIX文档ding -m serial -c smac_3s5z_qmix_config.py
模仿学习GAIL、SQIL、BCQGAIL文档ding -m serial_gail -c cartpole_gail_config.py
离线学习CQL、TD3-BC、Decision TransformerCQL文档python3 -u d4rl_cql_main.py
基于模型的RLMBPO、DreamerV3、STEVEMBPO文档python3 -u pendulum_mbpo_config.py
探索机制HER、RND、ICMHER文档python3 -u bitflip_her_dqn.py

2. 环境生态

DI-engine通过dizoo模块集成丰富环境,部分示例如下:

  • 经典控制:CartPole、Pendulum(单智能体离散/连续控制)。
  • Atari:Pong、Asterix(视觉密集型任务)。
  • 多智能体:SMAC(星际争霸微型战役)、PettingZoo(合作/竞争场景)。
  • 离线学习:D4RL(MuJoCo专家轨迹数据集)。
  • 真实应用:Metadrive(自动驾驶)、DI-star(星际争霸II决策AI)。

3. 核心组件

  • Policy:封装算法逻辑,支持自定义网络结构(如CNN、Transformer)。
  • Model:定义神经网络架构,支持PyTorch原生模块扩展。
  • Collector/Learner:分离数据采集与模型训练,支持异步并行。
  • Replay Buffer:支持优先经验回放(PER)、分层回放(PLR)等高级机制。
  • TreeTensor:树状数据结构,示例代码:
    import treetensor.torch as ttorch
    data = ttorch.randn({'obs': (3, 32, 32), 'action': (1,)})  # 嵌套张量
    stacked_data = ttorch.stack([data, data], dim=0)  # 批量操作
    

三、快速入门:安装与实战

1. 安装指南

方式1:Pip安装(推荐)
# 稳定版
pip install di-engine# 开发版(含额外依赖)
pip install "di-engine[all]"
方式2:Docker镜像

DI-engine提供预构建镜像,包含常见环境:

# CPU基础镜像
docker pull opendilab/ding:nightly# Atari环境镜像
docker pull opendilab/ding:nightly-atari# 运行示例(以CartPole为例)
docker run -it opendilab/ding:nightly \ding -m serial -c cartpole_dqn_config.py -s 0

2. 入门Demo:DQN玩转CartPole

步骤1:编写配置文件(cartpole_dqn_config.py
from ding.config import compile_config
from ding.policy import DQNPolicy
from ding.envs import GymEnvcfg = dict(env=dict(env_id='CartPole-v1',collector_env_num=8,evaluator_env_num=4,),policy=dict(cuda=False,model=dict(obs_shape=4,action_shape=2,),),
)
cfg = compile_config(cfg, seed=0, env=GymEnv, policy=DQNPolicy)
步骤2:运行训练
ding -m serial -c cartpole_dqn_config.py -s 0
步骤3:评估与可视化
from ding.utils import VideoRecorderenv = GymEnv('CartPole-v1', cfg.env)
policy = DQNPolicy(cfg.policy).load_checkpoint('ckpt.pth')
video_recorder = VideoRecorder(env, 'cartpole_demo.mp4')obs = env.reset()
while True:action = policy.predict(obs)obs, reward, done, info = env.step(action)video_recorder.record(obs)if done:break
video_recorder.close()

3. 多智能体示例:QMIX在SMAC环境

# SMAC 3s5z场景(3陆战队vs5 zealots)
ding -m serial -c smac_3s5z_qmix_config.py -s 0

四、进阶功能与最佳实践

1. 自定义环境迁移

若需接入自定义环境,需实现以下接口:

from ding.envs import BaseEnvclass MyEnv(BaseEnv):def __init__(self, env_id):super().__init__()# 初始化环境逻辑def reset(self):# 返回初始观测值return obsdef step(self, action):# 执行动作,返回obs, reward, done, inforeturn obs, reward, done, info@propertydef observation_space(self):# 定义观测空间(gym.Space格式)return Box(low=0, high=255, shape=(3, 32, 32))@propertydef action_space(self):# 定义动作空间return Discrete(10)

2. 模型自定义与网络设计

通过继承ding.model.BaseModel实现自定义网络:

import torch.nn as nn
from ding.model import BaseModelclass CustomModel(BaseModel):def __init__(self, obs_shape, action_shape):super().__init__()self.cnn = nn.Sequential(nn.Conv2d(obs_shape[0], 32, 3),nn.ReLU(),nn.Flatten())self.fc = nn.Linear(32*14*14, action_shape)def forward(self, x):x = self.cnn(x)return self.fc(x)

3. 分布式训练

使用DDP模式启动分布式训练:

ding -m distributed_ddp -c ppo_lunarlander_config.py -n 4

五、生态与社区支持

1. 工具与资源

  • DI-zoo:包含30+算法示例与基准环境,地址:https://github.com/opendilab/DI-zoo。
  • 教程与文档
    • 快速入门:3分钟上手指南
    • 中文文档:ReadTheDocs
    • 算法速查表:RL Algorithms Cheat Sheet
  • 开源工具链
    • TreeTensor:树状张量库,GitHub
    • DI-toolkit:决策智能工具包,PyPI

2. 社区参与

  • 问题反馈:在GitHub Issues提交BUG或功能请求。
  • 贡献代码:参考CONTRIBUTING.md,参与算法实现或文档完善。
  • 交流渠道
    • 微信社群:扫码添加“DI小助手”(见GitHub README)。
    • Discord/ Slack:链接。

3. 引用与许可

若在研究中使用DI-engine,请引用:

@misc{ding,title={DI-engine: A Universal AI System/Engine for Decision Intelligence},author={Niu, Yazhe et al.},howpublished={\url{https://github.com/opendilab/DI-engine}},year={2021}
}

DI-engine采用Apache 2.0许可证,允许商业使用与修改。

DI-engine核心算法实战指南

一、DQN(离散动作空间经典算法)

算法特性

  • 适用场景:离散动作空间、单智能体、稀疏奖励场景
  • 核心思想:基于Q-Learning,使用神经网络近似Q值函数,经验回放+目标网络稳定训练
  • DI-engine实现:支持Double DQN、Dueling DQN等变体,集成PER(优先经验回放)

支持环境

环境类别具体环境示例动作空间类型文档链接
经典控制CartPole-v1、CliffWalking离散CartPole文档
Atari游戏Pong、Breakout、Asterix离散Atari环境指南
文本决策TabMWP(数学文字题解答)离散TabMWP文档

快速上手:CartPole场景

1. 配置文件(cartpole_dqn_config.py
from ding.config import compile_config
from ding.envs import GymEnv
from ding.policy import DQNPolicycfg = dict(env=dict(env_id='CartPole-v1',collector_env_num=8,  # 并行采集环境数evaluator_env_num=4,   # 并行评估环境数n_evaluator_episode=20,  # 评估 episodes 数),policy=dict(cuda=False,model=dict(obs_shape=4,         # 观测空间维度action_shape=2,      # 动作空间大小(离散)encoder_hidden_size_list=[128, 128],  # 网络结构),learn=dict(update_per_collect=10,  # 每次采集后更新次数batch_size=32,         # 批量大小learning_rate=0.001,   # 学习率),collect=dict(n_sample=100,        # 每次采集样本数random_collect_size=1000,  # 随机初始化样本数),eval=dict(evaluator=dict(eval_freq=500, ))  # 评估频率),
)
cfg = compile_config(cfg, env=GymEnv, policy=DQNPolicy, seed=0)
2. 运行训练
# 单机训练
ding -m serial -c cartpole_dqn_config.py -s 0# 分布式训练(4进程)
ding -m distributed_ddp -c cartpole_dqn_config.py -n 4
3. 评估与可视化
from ding.utils import VideoRecorder
env = GymEnv('CartPole-v1', cfg.env)
policy = DQNPolicy(cfg.policy).load_checkpoint('output/ckpt.pth')
video_recorder = VideoRecorder(env, 'cartpole_demo.mp4')obs = env.reset()
while True:action = policy.predict(obs)  # 推理动作obs, reward, done, info = env.step(action)video_recorder.record(obs)if done:break
video_recorder.close()

二、PPO(连续/离散通用策略梯度算法)

算法特性

  • 适用场景:离散/连续动作空间、单智能体/多智能体(MAPPO变体)
  • 核心思想:近端策略优化,通过clip机制平衡策略更新步长
  • DI-engine扩展:支持GAE(广义优势估计)、分层回放(PLR)

支持环境

环境类别具体环境示例动作空间类型文档链接
经典控制LunarLanderContinuous-v2连续LunarLander文档
MuJoCoHalfCheetah-v4、Ant-v4连续MuJoCo环境指南
多智能体SMAC(3s5z场景)、PettingZoo离散/MARLSMAC文档

实战案例:LunarLander连续控制

1. 配置要点(连续动作特化)
cfg.policy.model = dict(obs_shape=8,action_shape=4,action_space='continuous',  # 显式声明连续空间actor_head_hidden_size=256,critic_head_hidden_size=256,
)
cfg.policy.learn = dict(epoch_per_collect=10,  # 每个采集周期训练轮数batch_size=2048,clip_ratio=0.2,        # PPO核心超参数
)
2. 运行命令
# 连续动作版示例(需指定环境ID)
ding -m serial_onpolicy -c lunarlander_ppo_continuous_config.py -s 0
3. 多智能体扩展(MAPPO)
# SMAC 3s5z多智能体场景
ding -m serial -c smac_3s5z_mappo_config.py -s 0

三、SAC(连续动作空间最大熵强化学习)

算法特性

  • 适用场景:连续动作空间、探索性强的环境(如机器人控制)
  • 核心思想:结合最大熵原理,同时优化策略熵与累积奖励
  • DI-engine优化:支持自动熵调整、多Q网络正则化

支持环境

环境类别具体环境示例动作空间类型文档链接
经典控制Pendulum-v1连续Pendulum文档
MuJoCoSwimmer-v4、Walker2d-v4连续MuJoCo环境列表
离线学习D4RL(HalfCheetah-medium)连续(离线)D4RL文档

代码示例:Pendulum摆锤控制

1. 关键配置
cfg.policy.model = dict(obs_shape=3,action_shape=1,twin_critic=True,       # 使用双Q网络action_space='continuous',
)
cfg.policy.learn = dict(target_entropy='auto',  # 自动计算目标熵discount_factor=0.99,
)
2. 运行与评估
# 训练命令
ding -m serial -c pendulum_sac_config.py -s 0# 离线学习(D4RL数据集)
python3 -u d4rl_sac_main.py --env_id halfcheetah-medium-v0

四、TD3(连续动作空间延迟策略更新算法)

算法特性

  • 适用场景:连续动作空间、高维动作空间(如机械臂控制)
  • 核心改进:双Q网络+策略延迟更新+动作噪声,缓解过估计问题
  • DI-engine集成:支持与HER(indsight experience replay)结合

支持环境

环境类别具体环境示例动作空间类型文档链接
经典控制MountainCarContinuous-v0连续MountainCar文档
PyBulletHumanoidBulletEnv-v0连续PyBullet环境指南
自定义环境机械臂控制(需迁移)连续环境迁移教程

快速运行:MountainCarContinuous

# 配置文件路径:ding/example/td3_mountaincar_continuous_config.py
ding -m serial -c td3_mountaincar_continuous_config.py -s 0
高级用法:结合HER探索
# 在配置中启用HER
cfg.policy.other = dict(replay_buffer=dict(type='her',her_type='future',  # HER类型(future/episode等))
)

五、环境支持总表(DQN/PPO/SAC/TD3适用场景)

算法离散动作空间连续动作空间多智能体场景离线学习场景典型环境示例
DQN✅(全支持)CartPole、Breakout
PPO✅(离散版)✅(连续版)✅(MAPPO)LunarLander、SMAC
SAC✅(全支持)✅(D4RL)Pendulum、D4RL-HalfCheetah
TD3✅(高维优先)MountainCarContinuous、Ant

六、进阶技巧:算法调优与环境适配

1. 离散vs连续动作空间关键差异

  • 动作处理
    • 离散:输出logits或Q值,通过argmax选择动作
    • 连续:输出均值+标准差(如SAC)或直接映射(如TD3),需限制动作范围
  • 网络结构
    • 离散:单分支输出(动作维度)
    • 连续:Actor-Critic双分支(如PPO/SAC)

2. 环境迁移三步法

  1. 继承BaseEnv:实现reset()step(action)observation_space等接口
  2. 适配动作格式
    • 离散:动作需为0~n-1的整数(np.int64
    • 连续:动作需为numpy数组,范围与action_space一致
  3. 测试兼容性:使用ding.utils.env_test验证环境接口
from ding.utils import env_test
env = MyCustomEnv()
env_test(env, coll_num=100)  # 测试采集兼容性

3. 超参数调优建议

算法关键超参数调优方向
DQNlearning_rategamma小学习率(1e-4~1e-3),gamma=0.99
PPOclip_ratiogae_lambdaclip_ratio=0.1~0.3,lambda=0.95
SACtarget_entropytau自动熵或手动设置(如-env_dim)
TD3policy_noisenoise_clip噪声=0.2,clip=0.5

七、总结:选择合适算法的决策树

离散
连续
需要探索
高维动作
多智能体
离线数据
问题类型
动作空间类型
DQN/PPO
SAC/PPO/TD3
SAC
TD3
MAPPO
CQL/TD3-BC

通过以上指南,可快速在DI-engine中落地主流RL算法。更多细节可参考官方文档:

  • 算法文档合集
  • 环境迁移教程
  • 完整示例代码
http://www.xdnf.cn/news/538723.html

相关文章:

  • Java大师成长计划之第27天:RESTful API设计与实现
  • 算法竞赛 Java 高精度 大数 小数 模版
  • MySQL故障排查域生产环境优化
  • IIR 巴特沃斯II型滤波器设计与实现
  • React Contxt详解
  • 孤立森林和随机森林主要区别
  • Java实现:如何在文件夹中查找重复文件
  • 如何从容应对面试?
  • vi实时查看日志
  • UA 编译和建模入门教程(zhanzhi学习笔记)
  • 基于大模型的脑出血全流程预测系统技术方案大纲
  • 物联网安全技术的最新进展与挑战
  • 深入理解pip:Python包管理的核心工具与实战指南
  • (1-5)Java 常用工具类、包装类、StringStringBuilderString
  • 计算机存储与数据单位的核心定义及换算逻辑
  • 学习黑客 PowerShell 详解
  • 相机Camera日志分析之十五:高通相机Camx 基于预览1帧的ConfigureStreams Usecase完整过程日志分析详解
  • 辅助驾驶平权与出海,Mobileye的双重助力
  • Cursor 模型深度分析:区别、优缺点及适用场景
  • IOS 创建多环境Target,配置多环境
  • GK的作用是什么?
  • C语言指针深入详解(三):数组名理解、指针访问数组、一维数组传参的本质、冒泡排序、二级指针、指针数组、指针数组模拟二维数组
  • opencascade如何保存选中的面到本地
  • 使用MCP驱动IDA pro分析样本
  • DV SSL证书管理主要有哪些功能?
  • C语言—字符函数和字符串函数
  • 如何实现从网页一键启动你的 Electron 桌面应用(zxjapp://)
  • pcie phy电气层(PCS)详解gen1、2 (rx)
  • 北斗卫星通讯终端的技术原理是什么
  • 2025-05-19 学习记录--Python-简易用户登录系统 + 计算天数