当前位置: 首页 > ds >正文

强化学习入门:从零开始实现DDQN

在上一篇文章《深度强化学习入门:从零开始实现DQN》中,我们已经完整介绍了CartPole环境、DQN的理论背景以及实现流程。本篇文章将在此基础上,进一步介绍DQN的缺点,并通过双重深度Q网络(Double DQN, 简称DDQN) 的来解决这些问题,最终训练出一个更加稳定和可靠的CartPole智能体。


DQN的缺陷:过估计问题(Overestimation Bias)

在DQN中,目标值的计算公式为:

y=r+γmax⁡a′Qtarget(s′,a′;wT)y = r + \gamma \max_{a'} Q_{\text{target}}(s', a'; w_T) y=r+γamaxQtarget(s,a;wT)

这里的max操作会直接选择使得Q值最大的动作,然后用目标网络计算其Q值。这带来一个问题:

  • 神经网络预测的Q值本身存在估计误差
  • 由于max操作总是会选择估计值最大的动作,因此会倾向于选择被高估的动作
  • 长期迭代后,这种系统性偏差会不断累积,导致Q值普遍被高估,训练不稳定。

DDQN的核心改进:解耦动作选择与评估

DDQN的提出正是为了解决过估计问题。其核心思想是:

  • 动作选择(Action Selection)主网络完成。
  • 动作评估(Action Evaluation)目标网络完成。

具体公式为:

y=r+γQtarget(s′,arg⁡max⁡a′Qmain(s′,a′;w),wT)y = r + \gamma Q_{\text{target}}(s', \arg\max_{a'} Q_{\text{main}}(s', a'; w), w_T) y=r+γQtarget(s,argamaxQmain(s,a;w),wT)

相比DQN的公式,区别在于:

  1. 我们先用主网络找到下一个状态s'下Q值最大的动作:

    a∗=arg⁡max⁡a′Qmain(s′,a′;w)a^* = \arg\max_{a'} Q_{\text{main}}(s', a'; w) a=argamaxQmain(s,a;w)

  2. 然后用目标网络来评估这个动作:

    Qtarget(s′,a∗;wT)Q_{\text{target}}(s', a^*; w_T) Qtarget(s,a;wT)

这样一来,“选动作”和“算值”分开进行,可以减轻过估计问题,从而让训练更稳定。


DDQN代码实现

在上一篇文章中,我们已经给出了DQN的代码结构。实现DDQN的改动非常小,主要在计算TD目标的部分。下面我们只展示和DQN不同的核心代码

修改点:Algorithm类中的learn方法

def learn(self, experiences):"""更新主网络"""# 将经验样本转化为tensor类型states, actions, next_states, rewards, dones = Processor.convert_tensors(experiences)# 根据主网络得到预测Q值current_q_values = self.model(states).gather(dim=1, index=actions.unsqueeze(-1))# ==============================DQN 更新target_q_values============================# 根据target_network得到目标Q值# with torch.no_grad():#     values = self.target_model(next_states).max(1)[0].detach()#     target_q_values = rewards + (1-dones) * (Config.GAMMA * values)# =============================DDQN 更新target_q_values=============================# 根据target_network得到目标Q值with torch.no_grad():next_actions = self.model(next_states).argmax(1)values = self.target_model(next_states).gather(dim=1, index=next_actions.unsqueeze(-1)).squeeze(dim=1)target_q_values = rewards + (1 - dones) * Config.GAMMA * values# 计算lossloss = F.mse_loss(current_q_values, target_q_values.unsqueeze(1))# 反向传播loss.backward()# 梯度更新self.optimizer.step()# 梯度清0self.optimizer.zero_grad()# 向监视器添加梯度信息self.monitor.add_loss_info(loss.detach().item())self.train_count += 1# 更新target_networkif self.train_count % Config.TARGET_UPDATE_INTERVAL == 0:self.update_target_network()

其他部分保持不变

  • 模型结构(Model类):仍然是多层感知机(MLP),输入状态输出动作的Q值。
  • 经验回放、训练流程(train_workflow):和DQN一样。
  • 配置文件与环境管理:也完全一致。

完整代码


结果

训练一千次回合,Loss以及回答长度随训练次数的变化趋势如下图所示:
在这里插入图片描述

用训练好的智能体测试一百次,100次episode的回合长度如下图所示:
在这里插入图片描述

http://www.xdnf.cn/news/20538.html

相关文章:

  • Ai8051 2.4寸320*240 ILI9341 I8080接口驱动
  • 人工智能学习:基于seq2seq模型架构实现翻译
  • 项目初始化上传git
  • Qemu-NUC980(四):SDRAM Interface Controller
  • 什么是“二合一矫平机”?——一篇技术科普
  • 主流的开源协议(MIT,Apache,GPL v2/v3)
  • Qt编程之信号与槽
  • 吴恩达机器学习(八)
  • make时设置链接器选项的2种方法
  • 【操作系统-Day 25】死锁 (Deadlock):揭秘多线程编程的“终极杀手”
  • Zoom AI 技术架构研究:联合式方法与多模态集成
  • 【LeetCode热题100道笔记】翻转二叉树
  • python炒股
  • C++ 20 新增特性以及代码示例
  • 同态加密库(Google FHE)
  • 神经网络的初始化:权重与偏置的数学策略
  • C# WinForm分页控件实现与使用详解
  • B.50.10.09-RPC核心原理与电商应用
  • MATLAB R2025a安装配置及使用教程(超详细保姆级教程)
  • 什么是云手机?
  • Vue3 - Echarts自定义主题引入(Error:ECharts is not Loaded,Error:default ,Error:module)
  • 攻击服务器的方式有哪些,对应的应对策略有哪些?
  • 联邦学习论文分享:Towards Building the Federated GPT:Federated Instruction Tuning
  • Leetcode hot100 最长连续序列
  • rh134第五章复习总结
  • SDRAM详细分析-08 数据手册解读
  • AI + 办公工具 = 应用案例
  • (论文速读)视觉语言模型评价中具有挑战性的选择题的自动生成
  • 大模型推理时的加速思路?
  • RabbitMq 初步认识