当前位置: 首页 > ai >正文

RL代码实践 02——策略迭代

目录

一、问题描述

二、问题分析和解决

1、策略迭代算法


一、问题描述

悬崖漫步

灰色格子代表悬崖,一旦进入就游戏失败;

绿色格子是终点,一旦进入就游戏成功;

白色格子是通路。

 

已知对于每个格子,可选的动作有4种,分别往上下左右走。

奖惩规则如下:

  • 普通步长:-1(鼓励少走步)
  • 撞墙:-2(略高于普通步长,减少撞墙)
  • 终点:+50,悬崖陷阱:-100(明确目标和风险)

求每个格子的策略,使可以找到通往终点的最优路径。

二、问题分析和解决

这是一个有模型的情况:

对于特定的状态(格子)s,采取特定的动作a后,能到达的下一状态已知且概率已知,

获得的奖励已知且概率已知。

 

可以用策略迭代算法或者值迭代算法

1、策略迭代算法

首先初始化策略,主要分为两个阶段:

(1)策略评估(Policy Evaluation)

需要计算出各个状态的value。

有两个方法:一个是根据贝尔曼公式求解,另一个是用迭代求解。

这里用迭代算法求解,直到value收敛(np.abs(values - old_values) < theta))才结束。

具体来说,

即对于第k轮迭代,

某状态的 state value = 各个动作的q值 * 各个动作的概率 之和,

而q值  action value =  reward + next_state_value * 0.9。

注意此时next_state_value是上一轮,即k-1轮的值。

(2)策略改进(Policy Improvement)

get_pi() 使用贪心策略(np.argmax)选择每个状态的最优动作,生成确定性策略(概率1赋予最优动作)。

也可以采用用随机性策略:若多个动作的 Q 值相同,可均分概率。(但是值迭代算法一般默认greedy,不用随机性策略)

 

外层循环交替执行策略评估和策略改进,直到策略稳定(pi == old_pi)。

# 悬崖漫步
import numpy as np# 设置(获取)格子状态
def get_state(row, col):if row!=3 or col==0:return 'ground' # 通路if row==3 and col==11:return 'terminal' # 终点return 'trap' # 悬崖for row in range(4):for col in range(12):if get_state(row, col)=='ground':print('o', end=' ') # 不换行,用空字符结尾elif get_state(row,col)=='terminal':print('p', end=' ')else:print('x', end=' ')print()# 在特定s做特定a,求得到的下一s和r
def move(row, col, action):# 如果当前状态已经是掉进悬崖或者到达终点,直接返回(因为游戏结束了,不会再有状态转移)if get_state(row,col) in ['terminal', 'trap']:return row, col, 0  # 让它待在原地不能移动,原地不动没有奖惩# 状态转移if action == 0: # 向上走row-=1elif action == 1: # 向下走row+=1elif action == 2: # 向左走col-=1elif action == 3: # 向右走col+=1# 注意限制不能走出地图外面去out = 0 # 标记是否出界if row<0 or row>3 or col<0 or col>11:out = 1row = max(0,row)row = min(3,row)col = max(0,col)col = min(11,col)# 获得奖励reward = -1 # 普通步长if get_state(row, col)=='trap':reward = -100 # 陷阱if get_state(row, col)=='terminal':reward = 50 # 终点if out==1:reward = -2 # 出界(撞墙)return row, col, reward# 初始化state value table
values = np.zeros((4,12))
# 初始化q-table
q_table = np.zeros((4,12,4))
# 初始化策略(每个格子下采取动作的概率)
pi = np.ones((4,12,4))*0.25# 计算q
def get_q(row, col, action):# 当前rewardnext_row, next_col, reward = move(row, col, action)# 下一状态的valuenext_state_value = values[next_row, next_col]# s,a对应的action valuereturn reward + next_state_value * 0.9# 计算q-table(选择所有动作的可能性)
def get_q_table():new_q_table = np.zeros((4,12,4))# 遍历所有格子for row in range(4):for col in range(12):# 对于特定格子(状态),四个动作的q值for action in range(4):new_q_table[row, col, action] = get_q(row, col, action)return new_q_table# policy evaluation(value update)
def get_values():new_values = np.zeros((4,12))# 遍历所有格子for row in range(4):for col in range(12):# 终止状态价值为0if get_state(row, col) in ['terminal', 'trap']:new_values[row, col] = 0else:# 该状态的value = 各个动作的q值 * 各个动作的概率 之和new_values[row,col] = np.sum(q_table[row, col] * pi[row, col])return new_values# policy improvement(policy update)
def get_pi():new_pi = np.zeros((4,12,4))# 遍历所有格子for row in range(4):for col in range(12):# 终止状态无需策略if get_state(row, col) in ['terminal', 'trap']:continue# # 该状态下,有最大q值的动作有几个# max_q = np.max(q_table[row,col])# count = np.sum(q_table[row,col]==max_q)# # 让这些动作均分概率,其它为0# for action in range(4):#     if q_table[row,col,action]==max_q:#         new_pi[row, col, action]=1/count#     else:#         new_pi[row, col, action] =0# greedya = np.argmax(q_table[row, col])for action in range(4):if action == a:new_pi[row, col, action] = 1else:new_pi[row, col, action] = 0return new_pi# 循环迭代策略评估和策略提升,寻找最优解
# 增加收敛判断,可提前终止
theta = 1e-6  # 收敛阈值
for _ in range(100):old_pi = pi.copy()  # 保存旧策略old_values = values.copy()# 策略评估:直到价值函数收敛while True:q_table = get_q_table()values = get_values()if np.all(np.abs(values - old_values) < theta):break  # 价值函数收敛,结束本轮评估old_values = values.copy()  # 注意:要更新旧价值,继续迭代# 策略提升:生成新策略pi = get_pi()if np.all(pi == old_pi):break  # 策略不再变化时终止# 打印结果
for row in range(4):for col in range(12):state = get_state(row,col)if state == 'terminal':print('🚩', end=' ')  # 终点elif state == 'trap':print('🪨', end=' ')  # 悬崖else:action = np.argmax(pi[row, col])if action == 0:print('⬆️',end=' ')elif action == 1:print('⬇️', end=' ')elif action == 2:print('👈',end=' ')else:print('👉',end=' ')print()

注意:

如果当前状态已经是掉进悬崖或者到达终点(终止状态),

则在状态转移move函数中,直接返回原地状态和reward=0(因为游戏结束了,不会再有状态转移,让它待在原地不动,没有奖惩);

在get_value函数中,终止状态的value为0(因为后续一直待在原地,奖惩一直为0);

在get_pi函数中,终止状态无需策略,后面打印时打印陷阱或终点的图标即可。

http://www.xdnf.cn/news/17420.html

相关文章:

  • IDEA 如何导入系统设置
  • Go语言中切片(Slice)的拷贝
  • IDEA 快捷编辑指南
  • Mybatis学习之动态SQL(八)
  • 大型动作模型LAM:让企业重复任务实现80%效率提升的AI技术架构与实现方案
  • android 使用openimagelib OpenImage 实现点击放大图片,浏览
  • 计算机网络---IP(互联网协议)
  • MySQL(190)如何优化MySQL的网络传输?
  • Mysql系列--5、表的基本查询(下)
  • 【后端】Java 8 特性 `User::getId` 语法(方法引用)介绍
  • 五种Excel表格导出方案
  • Java学习第一百二十三部分——HTTP/HTTPS
  • 18.3 全量微调:数据预处理之清洗与准备
  • windows的cmd命令【持续更新】
  • CompletableFuture实现Excel sheet页导出
  • 微信小程序中实现表单数据实时验证的方法
  • Python中的 __name__
  • Deep Learning MNIST手写数字识别 Mac
  • 深入解析Go设计模式:命令模式实战
  • 单链表专题---暴力算法美学(2)(有视频演示)
  • Linux 系统中,如何处理信号以避免竞态条件并确保程序稳定性?
  • Oracle 19C 查看卡慢的解决思路
  • 使用快捷键将当前屏幕内容滚动到边缘@首行首列@定位到第一行第一个字符@跳转到4个角落
  • 【2025CVPR-图象去雾方向】BEVDiffuser:基于地面实况引导的BEV去噪的即插即用扩散模型
  • 诺基亚就4G/5G相关专利起诉吉利对中国汽车及蜂窝模组企业的影响
  • PHP项目运行
  • 亚麻云之数据安家——RDS数据库服务入门
  • Jenkins | 账号及权限管理
  • 从 GPT‑2 到 gpt‑oss:解析架构的迭代
  • 在windows安装colmap并在cmd调用