当前位置: 首页 > ops >正文

从代码学习深度强化学习 - PPO PyTorch版

文章目录

  • 前言
  • PPO 算法简介
    • 从 TRPO 到 PPO
    • PPO 的两种形式:惩罚与截断
  • 代码实践:PPO 解决离散动作空间问题 (CartPole)
    • 环境与工具函数
    • 定义策略与价值网络
    • PPO 智能体核心实现
    • 训练与结果
  • 代码实践:PPO 解决连续动作空间问题 (Pendulum)
    • 环境准备
    • 适用于连续动作的网络
    • PPO 智能体 (连续版)
    • 训练与结果
  • 总结


前言

欢迎来到深度强化学习(DRL)的世界!在众多 DRL 算法中,Proximal Policy Optimization (PPO) 无疑是最受欢迎和广泛应用的算法之一。它由 OpenAI 在 2017 年提出,以其出色的性能、相对简单的实现和稳定的训练过程而著称,成为了许多研究和应用的基准算法。

本篇博客旨在通过一个完整的 PyTorch 实现,带您从代码层面深入理解 PPO 算法。我们将不仅仅是看公式,更是要“动手”,一步步构建、训练和分析 PPO 智能体。为了全面掌握其应用,我们将分别在经典的离散动作空间(CartPole-v1)和连续动作空间(Pendulum-v1)两个环境中进行实践。

无论您是 DRL 的初学者,还是希望巩固 PPO 知识的实践者,相信通过这篇代码驱动的教程,您都能对 PPO 有一个更具体、更深刻的认识。

完整代码:下载链接


PPO 算法简介

在深入代码之前,我们先快速回顾一下 PPO 的核心思想。

从 TRPO 到 PPO

PPO 的思想源于 TRPO(Trust Region Policy Optimization)。TRPO 旨在通过限制每次策略更新的步长,确保更新后的策略不会与旧策略偏离太远,从而保证学习的稳定性。它的优化目标如下:

TRPO 通过一个 KL 散度的约束来限制策略更新的区域,但这个约束的计算过程非常复杂,涉及泰勒展开、共轭梯度、线性搜索等,导致其实现难度大,运算量也非常可观。

PPO 的出现正是为了解决这个问题。它继承了 TRPO 的核心思想,即在更新策略时不要“步子迈得太大”,但采用了更简单、更易于实现的方法。

PPO 的两种形式:惩罚与截断

PPO 主要有两种形式:PPO-PenaltyPPO-Clip

  1. PPO-Penalty (惩罚)
    它将 TRPO 的 KL 散度约束作为一个惩罚项直接放入目标函数中,变成一个无约束的优化问题,并通过一个动态调整的系数 β 来控制惩罚的力度。

  2. PPO-Clip (截断)
    这是更常用的一种形式,也是我们代码将要实现的版本。它直接在目标函数中进行截断(clip),以保证新的参数和旧的参数的差距不会太大。

    其核心思想在于 clip 函数。我们定义一个比率 r(θ) 为新策略与旧策略输出同一动作的概率之比。

    • 优势函数 A > 0 时(即当前动作优于平均水平),我们希望增大这个动作的概率,但 r(θ) 的上限被截断在 1+ε,防止策略更新过于激进。
    • 优势函数 A < 0 时(即当前动作劣于平均水平),我们希望减小这个动作的概率,但 r(θ) 的下限被截断在 1-ε,同样是为了限制更新幅度。

    下图直观地展示了 PPO-Clip 的目标函数 L^Clip 与概率比 r(θ) 的关系:

大量的实验表明,PPO-Clip 的性能通常比 PPO-Penalty 更好且更稳定。因此,我们的代码实践将专注于 PPO-Clip 的实现。

理论铺垫结束,让我们开始编码吧!

代码实践:PPO 解决离散动作空间问题 (CartPole)

我们将从经典的 CartPole-v1 环境开始,它要求智能体通过向左或向右施加力来保持杆子竖直不倒,是一个典型的离散动作空间问题(动作:0-向左,1-向右)。

环境与工具函数

首先,我们定义一些通用的工具函数并初始化环境。这里的核心是 compute_advantage 函数,它实现了广义优势估计(GAE),这是一种在偏差和方差之间取得平衡的优势函数计算方法,对于稳定策略梯度算法的训练至关重要。

PPO离散动作.ipynb

"""
强化学习工具函数集
包含广义优势估计(GAE)和数据平滑处理功能
"""import torch
import numpy as npdef compute_advantage(gamma, lmbda, td_delta):"""计算广义优势估计(Generalized Advantage Estimation,GAE)GAE是一种在强化学习中用于减少策略梯度方差的技术,通过对时序差分误差进行指数加权平均来估计优势函数,平衡偏差和方差的权衡。参数:gamma (float): 折扣因子,维度: 标量取值范围[0,1],决定未来奖励的重要性lmbda (float): GAE参数,维度: 标量  取值范围[0,1],控制偏差-方差权衡lmbda=0时为TD(0)单步时间差分,lmbda=1时为蒙特卡洛方法用采样到的奖励-状态价值估计td_delta (torch.Tensor): 时序差分误差序列,维度: [时间步数]包含每个时间步的TD误差值返回:torch.Tensor: 广义优势估计值,维度: [时间步数]与输入td_delta维度相同的优势函数估计数学公式:A_t^GAE(γ,λ) = Σ_{l=0}^∞ (γλ)^l * δ_{t+l}其中 δ_t = r_t + γV(s_{t+1}) - V(s_t) 是TD误差"""# 将PyTorch张量转换为NumPy数组进行计算# td_delta维度: [时间步数] -> [时间步数]td_delta = td_delta.detach().numpy() # 因为A用来求g的,需要梯度,防止梯度向下传播# 初始化优势值列表,用于存储每个时间步的优势估计# advantage_list维度: 最终为[时间步数]advantage_list = []# 初始化当前优势值,从序列末尾开始反向计算# advantage维度: 标量advantage = 0.0# 从时间序列末尾开始反向遍历TD误差# 反向计算是因为GAE需要利用未来的信息# delta维度: 标量(td_delta中的单个元素)for delta in td_delta[::-1]:  # [::-1]实现序列反转# GAE递归公式:A_t = δ_t + γλA_{t+1}# gamma * lmbda * advantage: 来自未来时间步的衰减优势值# delta: 当前时间步的TD误差# advantage维度: 标量advantage = gamma * lmbda * advantage + delta# 将计算得到的优势值添加到列表中# advantage_list维度: 逐步增长到[时间步数]advantage_list.append(advantage)# 由于是反向计算,需要将结果列表反转回正确的时间顺序# advantage_list维度: [时间步数](时间顺序已恢复)advantage_list.reverse()# 将NumPy列表转换回PyTorch张量并返回# 返回值维度: [时间步数]return torch.tensor(advantage_list, dtype=torch.float)def moving_average(data, window_size):"""计算移动平均值,用于平滑奖励曲线该函数通过滑动窗口的方式对时间序列数据进行平滑处理,可以有效减少数据中的噪声,使曲线更加平滑美观。常用于强化学习中对训练过程的奖励曲线进行可视化优化。参数:data (list): 原始数据序列,维度: [num_episodes]包含需要平滑处理的数值数据(如每轮训练的奖励值)window_size (int): 移动窗口大小,维度: 标量决定了平滑程度,窗口越大平滑效果越明显但也会导致更多的数据点丢失返回:list: 移动平均后的数据,维度: [len(data) - window_size + 1]返回的数据长度会比原数据少 window_size - 1 个元素这是因为需要足够的数据点来计算第一个移动平均值示例:>>> data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]  # 维度: [10]>>> smoothed = moving_average(data, 3)       # window_size = 3>>> print(smoothed)  # 输出: [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]  维度: [8]"""# 边界检查:如果数据长度小于窗口大小,直接返回原数据# 这种情况下无法计算移动平均值# data维度: [num_episodes], window_size维度: 标量if len(data) < window_size:return data# 初始化移动平均值列表# moving_avg维度: 最终为[len(data) - window_size + 1]moving_avg = []# 遍历数据,计算每个窗口的移动平均值# i的取值范围: 0 到 len(data) - window_size# 循环次数: len(data) - window_size + 1# 每次循环处理一个滑动窗口位置for i in range(len(data) - window_size + 1):# 提取当前窗口内的数据切片# window_data维度: [window_size]# 包含从索引i开始的连续window_size个元素# 例如:当i=0, window_size=3时,提取data[0:3]window_data = data[i:i + window_size]# 计算当前窗口内数据的算术平均值# np.mean(window_data)维度: 标量# 将平均值添加到结果列表中moving_avg.append(np.mean(window_data))# 返回移动平均后的数据列表# moving_avg维度: [len(data) - window_size + 1]return moving_avg``````python
"""
强化学习环境初始化模块
用于创建和配置OpenAI Gym环境
"""import gym# 环境配置
# 定义要使用的强化学习环境名称
# CartPole-v1是经典的平衡杆控制问题:
# - 状态空间:4维连续空间(车位置、车速度、杆角度、杆角速度)
# - 动作空间:2维离散空间(向左推车、向右推车)
# - 目标:保持杆子平衡尽可能长的时间
# env_name维度: 标量(字符串)
env_name = 'CartPole-v1'# 创建强化学习环境实例
# gym.make()函数根据环境名称创建对应的环境对象
# 该环境对象包含了状态空间、动作空间、奖励函数等定义
# env维度: gym.Env对象(包含状态空间[4]和动作空间[2]的环境实例)
# env.observation_space.shape: (4,) - 观测状态维度
# env.action_space.n: 2 - 离散动作数量
env = gym.make(env_name)

定义策略与价值网络

PPO 是一种 Actor-Critic 架构的算法。我们需要定义两个网络:

  • 策略网络 (PolicyNet):作为 Actor,输入状态,输出一个动作的概率分布。
  • 价值网络 (ValueNet):作为 Critic,输入状态,输出该状态的价值估计 V(s)。
"""
PPO(Proximal Policy Optimization)算法实现
包含策略网络、价值网络和PPO智能体的完整定义
"""import torch
import torch.nn.functional as F
import numpy as npclass PolicyNet(torch.nn.Module):"""策略网络(Actor Network)用于输出动作概率分布,指导智能体如何选择动作"""def __init__(self, state_dim, hidden_dim, action_dim):"""初始化策略网络参数:state_dim (int): 状态空间维度,维度: 标量对于CartPole-v1环境,state_dim=4hidden_dim (int): 隐藏层神经元数量,维度: 标量控制网络的表达能力action_dim (int): 动作空间维度,维度: 标量对于CartPole-v1环境,action_dim=2"""super(PolicyNet, self).__init__()# 第一层全连接层:状态输入 -> 隐藏层# 输入维度: [batch_size, state_dim] -> 输出维度: [batch_size, hidden_dim]self.fc1 = torch.nn.Linear(state_dim, hidden_dim)# 第二层全连接层:隐藏层 -> 动作概率# 输入维度: [batch_size, hidden_dim] -> 输出维度: [batch_size, action_dim]self.fc2 = torch.nn.Linear(hidden_dim, action_dim)def forward(self, x):"""前向传播过程参数:x (torch.Tensor): 输入状态,维度: [batch_size, state_dim]返回:torch.Tensor: 动作概率分布,维度: [batch_size, action_dim]每行为一个状态对应的动作概率分布,概率和为1"""# 第一层 + ReLU激活函数# x维度: [batch_size, state_dim] -> [batch_size, hidden_dim]x = F.relu(self.fc1(x))# 第二层 + Softmax激活函数,输出概率分布# x维度: [batch_size, hidden_dim] -> [batch_size, action_dim]# dim=1表示在第1维(动作维度)上进行softmax,确保每行概率和为1return F.softmax(self.fc2(x), dim=1)class ValueNet(torch.nn.Module):"""价值网络(Critic Network)用于估计状态价值函数V(s),评估当前状态的好坏"""def __init__(self, state_dim, hidden_dim):"""初始化价值网络参数:state_dim (int): 状态空间维度,维度: 标量对于CartPole-v1环境,state_dim=4hidden_dim (int): 隐藏层神经元数量,维度: 标量控制网络的表达能力"""super(ValueNet, self).__init__()# 第一层全连接层:状态输入 -> 隐藏层# 输入维度: [batch_size, state_dim] -> 输出维度: [batch_size, hidden_dim]self.fc1 = torch.nn.Linear(state_dim, hidden_dim)# 第二层全连接层:隐藏层 -> 状态价值(标量)# 输入维度: [batch_size, hidden_dim] -> 输出维度: [batch_size, 1]self.fc2 = torch.nn.Linear(hidden_dim, 1)def forward(self, x):"""前向传播过程参数:x (torch.Tensor): 输入状态,维度: [batch_size, state_dim]返回:torch.Tensor: 状态价值估计,维度: [batch_size, 1]每行为一个状态对应的价值估计"""# 第一层 + ReLU激活函数# x维度: [batch_size, state_dim] -> [batch_size, hidden_dim]x = F.relu(self.fc1(x))# 第二层,输出状态价值(无激活函数,可以输出负值)# x维度: [batch_size, hidden_dim] -> [batch_size, 1]return self.fc2(x)

PPO 智能体核心实现

这是我们 PPO 算法的核心。PPO 类封装了 Actor 和 Critic,并实现了 take_action(动作选择)和 update(网络更新)两个关键方法。请特别关注 update 函数,它完整地实现了 PPO-Clip 的目标函数计算和参数更新逻辑。

class PPO:"""PPO(Proximal Policy Optimization)算法实现采用截断方式防止策略更新过大,确保训练稳定性"""def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,lmbda, epochs, eps, gamma, device):"""初始化PPO智能体参数:state_dim (int): 状态空间维度,维度: 标量hidden_dim (int): 隐藏层神经元数量,维度: 标量action_dim (int): 动作空间维度,维度: 标量actor_lr (float): Actor网络学习率,维度: 标量critic_lr (float): Critic网络学习率,维度: 标量lmbda (float): GAE参数λ,维度: 标量,取值范围[0,1]epochs (int): 每次更新的训练轮数,维度: 标量eps (float): PPO截断参数ε,维度: 标量,通常取0.1-0.3gamma (float): 折扣因子γ,维度: 标量,取值范围[0,1]device (torch.device): 计算设备(CPU或GPU),维度: 标量"""# 初始化Actor网络(策略网络)# 网络参数维度:fc1权重[state_dim, hidden_dim], fc2权重[hidden_dim, action_dim]self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)#
http://www.xdnf.cn/news/15140.html

相关文章:

  • Go语言WebSocket编程:从零打造实时通信利器
  • Linux操作系统从入门到实战:怎么查看,删除,更新本地的软件镜像源
  • 蔚来测开一面:HashMap从1.7开始到1.8的过程,既然都解决不了并发安全问题,为什么还要进一步解决环形链表的问题?
  • Spring的事务控制——学习历程
  • HarmonyOS NEXT端云一体化开发初体验
  • [Dify] -基础入门4-快速创建你的第一个 Chat 应用
  • 牛客:HJ17 坐标移动[华为机考][字符串]
  • Leaflet面试题及答案(1-20)
  • [实战]调频三角波和锯齿波信号生成(完整C代码)
  • 深入浅出:什么是MCP(模型上下文协议)?
  • 力扣网编程134题:加油站(双指针)
  • C++中柔性数组的现代化替代方案:从内存布局优化到标准演进
  • Debian:从GNOME切换到Xfce
  • 扫描文件 PDF / 图片 纠斜 | 图片去黑边 / 裁剪 / 压缩
  • I2C集成电路总线
  • Semi-Supervised Single-View 3D Reconstruction via Prototype Shape Priors
  • 基于Java Spring Boot开发的旅游景区智能管理系统 计算机毕业设计源码32487
  • linux网络编程之单reactor模型(一)
  • Python 数据建模与分析项目实战预备 Day 2 - 数据构建与字段解析(模拟简历结构化数据)
  • 【前端】【组件库开发】【原理】【无框架开发】现代网页弹窗开发指南:从基础到优化
  • GNhao,获取跨境手机SIM卡跨境通信新选择!
  • 手机恢复出厂设置怎么找回数据?Aiseesoft FoneLab for Android数据恢复工具分享
  • Java中的泛型继承
  • 深度学习篇---昇腾NPUCANN 工具包
  • 《Java EE与中间件》实验三 基于Spring Boot框架的购物车
  • BLOB 数据的插入与读取详解
  • Linux驱动学习day22(interrupt子系统)
  • [python]在drf中使用drf_spectacular
  • 卢比危机下的金融破局:科伦坡交易所技术升级作战图
  • SpringBoot JWT