当前位置: 首页 > ai >正文

强化学习:高级策略梯度理论与优化方法

如果您想学习强化学习,我推荐David Sliver的讲座😊:RL Course by David Silver - Lecture 1: Introduction to Reinforcement Learning - YouTube

在本文开始前,如果您还没读过我的前一篇文章,由此进:

强化学习:基础理论与高级DQN算法及策略梯度基础-CSDN博客

自然策略梯度(NPG)与信息几何

1.策略空间的黎曼流式结构

  • 策略分布族:将策略参数空间视为统计流形 \mathcal{M}={\pi_\theta|\theta \in \Theta}

  • Fisher信息矩阵(黎曼度量张量):

F(\theta)=\mathbb{E}_{s\sim p^\pi,a\sim \pi_\theta}[\nabla_\theta\log \pi_\theta(a|s)\nabla_\theta\log\pi_\theta(a|s)^\top]

  • KL散度的局部近似(泰勒展开到二阶):

\text{KL}(\pi_\theta||\pi_{\theta+\Delta_\theta})\approx \frac{1}{2}\Delta\theta^\top F(\theta)\Delta\theta

2.自然梯度定义

  • 传统梯度方向在欧氏空间,自然梯度在黎曼空间:

\tilde{\nabla}_\theta J=F(\theta)^{-1}\nabla_\theta J

  • 最优更新方向证明:

求解带KL约束的优化问题:

\underset{\Delta\theta}{\max}J(\theta+\Delta\theta)\quad s.t. \quad \text{KL}(\pi_\theta||\pi_{\theta+\Delta\theta}) \leq \epsilon

通过拉格朗日乘子法得到自然梯度方向

3.自然策略梯度更新规则

\theta_{k+1}=\theta_k+\alpha F(\theta_k)^{-1}\nabla_\theta J(\theta_k)

实际计算技巧:

  • 使用共轭梯度法避免显示求逆

  • 增广矩阵法处理秩亏问题

兼容函数逼近定理

1.严格条件陈述

当价值函数逼近器Q_w(s,a)满足:

  1. 兼容性: \nabla_wQ_w(s,a)=\nabla_\theta\log\pi_\theta(a|s)

  2. 最小化均方误差:

w^*=\arg\underset{w}{\max}\mathbb{E}[(Q_w(s,a)-Q^\pi(s,a))^2]

则策略梯度估计无偏:

\nabla_\theta J(\theta)=\mathbb{E}[\nabla_\theta\log\pi_\theta(a|s)Q_w(s,a)]

2.证明概要

  • 条件1保证价值函数梯度与策略梯度在同一方向

  • 条件2保证 Q_wQ^\pi 在兼容子空间上的正交投影

  • 联合推导可得:\mathbb{E}[\nabla_\theta\log\pi_\theta(Q_w-Q^\pi)]=0

信任区域策略优化(TRPO)

1.核心目标与约束

优化问题:

\underset{\theta}{max}

s.t.\mathbb{E}_s[\text{KL}(\pi_{\theta_{old}}||\pi_\theta)(s)] \leq \delta

2.目标函数的局部近似

  • 优势函数近似(一阶泰勒展开):

L(\theta)\approx L(\theta_{old})+g^\top(\theta-\theta_{old})

其中 g=\nabla_\theta L|_{\theta=\theta_{old}}

  • KL散度的二阶近似:

\text{KL}(\theta_{old}||\theta)\approx \frac{1}{2}\Delta\theta^\top F(\theta_{old})\Delta\theta

F 是Fisher信息矩阵

3.解析解推导

通过拉格朗日乘子法得到最优更新方向:

\theta^*=\theta

自然梯度方向 F^{-1}g 在策略流形上是最速上升方向

4.实现中的共轭梯度法

求解 F^{-1}g 的步骤

  1. 计算Fisher-vector product:Fv=\mathbb{E}[\nabla(\nabla\log\pi)^\top v]

  2. 使用共轭梯度法迭代求解 Fx=g

  3. 通过回溯线搜索确保KL约束

近端策略优化(PPO)

1.剪切目标函数

L^{\text{CLIP}}(\theta)=\mathbb{E}_t[\min(r_t(\theta)A_t,clip(r_t(\theta),1-\epsilon,1+\epsilon)A_t)]

其中 r_t(\theta)=\frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}

剪切区域分析:

  • A_t > 0,限制最大更新幅度为 (1 + \epsilon)A_t

  • A_t < 0,限制最小更新幅度为 (1-\epsilon)A_t

2.自适应KL惩罚项

目标函数:

L^{\text{KLPEN}}(\theta)=\mathbb{E}_t[r_t(\theta)A_t-\beta\text{KL}(\pi\theta_{old}||\pi_\theta)]

  • \beta 自适应规则:

\beta_{k+1}=\begin{cases}\beta_k/1.5 \quad if\:\text{KL}<\delta_{\text{low}}\\\beta_k \times2 \quad if\: \text{KL}>\delta_{\text{high}} \\\beta_k \quad \quad \:\:\:otherwise\end{cases}

典型设置:\delta_{\text{low}}=0.01, \delta_{\text{high}}=0.1

3.重要性采样方差控制

原始重要性权重方差:

\text{Var}(r_t)=\mathbb{E}[(\frac{\pi_\theta}{\pi_{old}}-1)^2]

剪切后的方差上界:

\text{Var}(r_t^{\text{clip}}\leq \epsilon^2\mathbb{E}[A_t^2])

直接偏好优化(DPO)

1.从奖励模型到策略的隐式转换

基于Bradley-Terry模型:

p^*(y_1 \succ y_2|x)=\frac{\exp(\beta\mathcal{R}(x,y_1))}{\exp(\beta\mathcal{R}(x,y_1))+\exp(\beta\mathcal{R}(x,y_2))}

关键替换:用策略表示奖励函数

\mathcal{R}(x,y)=\beta \log \frac{\pi_\theta(y|x)}{\pi_{\text{ref}}(y|x)}+\beta\log Z(x)

2.目标函数推导

消去奖励函数后得到:

\mathcal{L}_\text{DPO}=-\mathbb{E}_(x,y_w,y_l)[\log \sigma(\beta\log\frac{\pi_\theta(y_w|x)}{\pi_\text{ref}(y_w|x)}-\beta\log\frac{\pi_\theta(y_l|x)}{\pi_\text{ref}(y_l|x)})]

其中 \sigma 是sigmoid函数

3.隐式KL约束分析

DPO等价于带动态约束的优化:

\underset{\theta}{\max}\mathbb{E}[\log\sigma(\beta\Delta\log\pi)]\quad s.t. \quad\text{KL}(\pi_\theta||\pi_\text{ref})\leq C

4.梯度分析

梯度计算公式:

\nabla_\theta \mathcal{L}_{\text{DPO}}=-\beta \mathbb{E}[\sigma(\hat{r_l}-\hat{r_w})(\nabla_\delta\log\pi_\theta(y_w|x))-\nabla_\delta\log\pi_\theta(y_l|x))]

其中  \hat{r_i}=\log\frac{\pi_\theta(y_i|x)}{\pi_{\text{ref}}(y_i|x)}

如果您对RL和测试时间扩展感兴趣,我自推这篇文章:从理论到实践:带你快速学习基于PRM的三种搜索方法-CSDN博客

http://www.xdnf.cn/news/2121.html

相关文章:

  • react的fiber 用法
  • 6.1腾讯技术岗2025面试趋势前瞻:大模型、云原生与安全隐私新动向
  • 重定向和语言级缓冲区【Linux操作系统】
  • 用python写一个相机选型的简易程序
  • RTMP 协议解析 1
  • Linux0.11内存管理:相关代码
  • 从零实现 registry.k8s.io/pause:3.8 镜像的导出与导入
  • 山东大学软件学院项目实训-基于大模型的模拟面试系统-网页图片显示问题
  • 基于开源技术体系的品牌赛道力重构:AI智能名片与S2B2C商城小程序源码驱动的品类创新机制研究
  • 月之暗面开源 Kimi-Audio-7B-Instruct,同时支持语音识别和语音生成
  • 推荐三款GitHub上高星开源的音乐搜索平台
  • 华为OD机试真题——素数之积RSA加密算法(2025A卷:100分)Java/python/JavaScript/C/C++/GO最佳实现
  • JDK 17 与 Spring Cloud Gateway 新特性实践指南
  • Flask + ajax上传文件(三)--图片上传与OCR识别
  • DataStreamAPI实践原理——计算模型
  • 上位机知识篇---时钟分频
  • [mysql]数据类型精讲下
  • 【Linux网络】HTTP协议全解析 - 从请求响应到方法与Header
  • SpringBoot UserAgentUtils获取用户浏览器 操作系统设备统计 信息统计 日志入库
  • 从基础到实战的量化交易全流程学习:1.1 量化交易本质与行业生态
  • C++---类和对象(二)
  • VO包装类和实体类分别是什么?区别是什么?
  • C++学习笔记(四十)——STL之归约算法
  • 深入探究 MySQL 架构:从查询到硬件
  • Apache NetBeans 25 发布
  • 任务管理系统,Java+Vue,含源码与文档,科学规划任务节点,全程督办保障项目落地提效
  • priority_queue的学习
  • GoFly快速开发框架新增UI素材库-帮助开发者快速开发管理后台UI基于ArcoDesign框架开发
  • 服务器传输数据存储数据建议 传输慢的原因
  • 文本预处理(NLTK)