
Revisiting Reinforcement Learning (Dynamic Programming)

Dynamic programming (DP) is a fundamental algorithmic technique, and its recursive subproblem structure mirrors the form of the Bellman equation very closely, which makes DP a natural tool for solving it.

Derivation of the Bellman equation:

V(s) = \mathbb{E}\big[\, G_t \mid s_t = s \,\big]

Expanding the return with G_t = r_{t+1} + \gamma G_{t+1}:

V(s) = \mathbb{E}\!\left[\, r_{t+1} + \gamma G_{t+1} \,\middle|\, s_t = s \right]

V(s) = \mathbb{E}\!\left[\, r_{t+1} + \gamma V(s_{t+1}) \,\middle|\, s_t = s \right]

The Bellman equation shows that computing the value of one state requires the value of its successor state. Unrolling this recursion gives a tree-like structure that is solved bottom-up: to obtain the value at the top, we first solve the subproblems at the bottom and work upward layer by layer.
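For a finite MRP, this recursion can also be written in matrix form, V = R + γPV, and solved in closed form as V = (I − γP)⁻¹R. Below is a minimal sketch of that solution; the 3-state chain, its transition matrix P, and its reward vector R are hypothetical and chosen only to illustrate the idea.

import numpy as np

# Hypothetical 3-state MRP, used only to illustrate the closed-form solution.
P = np.array([[0.6, 0.4, 0.0],    # P[s, s']: probability of moving from s to s'
              [0.0, 0.5, 0.5],
              [0.0, 0.0, 1.0]])   # state 2 is absorbing
R = np.array([-1.0, -2.0, 0.0])   # expected immediate reward in each state
gamma = 0.9

# Bellman equation in matrix form: V = R + gamma * P V
# => (I - gamma * P) V = R, solved directly instead of unrolling the recursion.
V = np.linalg.solve(np.eye(3) - gamma * P, R)
print(V)   # state values of the hypothetical chain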

The Bellman equation above is defined on a Markov reward process (MRP). On top of the MRP we usually introduce actions: instead of moving passively from one state to the next, the agent takes an action in each state and that action determines the next state. This is the Markov decision process (MDP). Because a step of deliberate choice is introduced, we obtain two equations, the Bellman expectation equation and the Bellman optimality equation:

Bellman expectation equation:

V^\pi(s) = \sum_{a} \pi(a \mid s) \left[ R(s,a) + \gamma \sum_{s'} P(s' \mid s,a)\, V^\pi(s') \right]

Q^\pi(s,a) = R(s,a) + \gamma \sum_{s'} P(s' \mid s,a)\, \sum_{a'} \pi(a' \mid s')\, Q^\pi(s',a')
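For reference, the two expectation equations are tied together by the standard relation between the state-value and action-value functions; substituting one line into the other recovers the equations above:

V^\pi(s) = \sum_{a} \pi(a \mid s)\, Q^\pi(s,a)

Q^\pi(s,a) = R(s,a) + \gamma \sum_{s'} P(s' \mid s,a)\, V^\pi(s')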

Bellman optimality equation:

V^*(s) = \max_{a} \sum_{s'} P(s' \mid s, a) \Big[ R(s, a) + \gamma V^*(s') \Big]

Q^*(s, a) = \sum_{s'} P(s' \mid s, a) \Big[ R(s, a) + \gamma \max_{a'} Q^*(s', a') \Big]
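The optimal value functions are related in the same way, and once Q* is known an optimal deterministic policy can be read off greedily (standard identities, added here for reference):

V^*(s) = \max_{a} Q^*(s,a)

\pi^*(s) = \arg\max_{a} Q^*(s,a)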

Because actions are introduced, and a state can admit several of them, the equations now sum over all actions available in that state, weighted by the probability that the policy selects each one. The rest keeps the same shape: the transition matrix simply depends on the chosen action instead of only on the current state. The recursive structure is unchanged, so the problem is still well suited to a DP approach.

Based on the MDP and the Bellman expectation/optimality equations, and following the DP idea, we can design model-based dynamic programming algorithms: value iteration and policy iteration. Both follow the same pattern. First, policy/value evaluation applies the Bellman expectation equation repeatedly until all state values have converged and are stable under the current policy. Then, using the resulting state values (plural: the values of all states) together with the Bellman optimality equation, we perform policy improvement. Evaluation must iterate the expectation equation because it is computing an expectation: every successor state under the transition model has to be swept again and again until the values settle at the true expected return. (This is unrealistic in robotic manipulation, where neither all trajectories nor the transition matrix is available, so model-free algorithms dominate there; model-based algorithms also exist, but unlike the DP algorithms they approximate the transition probabilities with a neural network.) A compact sketch of this evaluate-then-improve loop is given right after this paragraph; the full CliffWalking implementations follow it.
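As a compact view of that evaluate-then-improve loop, here is a minimal matrix-form sketch of policy iteration for a generic tabular MDP. The arrays P and R, their shapes, and the function name are assumptions made purely for illustration; the CliffWalking code below is the actual worked implementation.

import numpy as np

def policy_iteration(P, R, gamma=0.9, theta=1e-6):
    # P: (S, A, S) transition probabilities, R: (S, A) expected rewards -- hypothetical inputs.
    S, A, _ = P.shape
    pi = np.full((S, A), 1.0 / A)            # start from a uniform random policy
    V = np.zeros(S)
    while True:
        # Policy evaluation: sweep the Bellman expectation equation until the values stabilize.
        while True:
            Q = R + gamma * P @ V             # Q[s, a] = R[s, a] + gamma * sum_s' P[s, a, s'] * V[s']
            V_new = (pi * Q).sum(axis=1)      # V[s]   = sum_a pi(a | s) * Q[s, a]
            delta = np.max(np.abs(V_new - V))
            V = V_new
            if delta < theta:
                break
        # Policy improvement: act greedily with respect to the evaluated Q values.
        Q = R + gamma * P @ V
        new_pi = np.zeros_like(pi)
        new_pi[np.arange(S), Q.argmax(axis=1)] = 1.0
        if np.array_equal(new_pi, pi):        # the policy is stable, so it is optimal
            return pi, V
        pi = new_pi

Note that this sketch breaks ties by picking a single greedy action per state, whereas the CliffWalking code below spreads probability evenly over all maximizing actions; both are valid greedy improvements.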

The value iteration and policy iteration implementations below are based on the code from Hands-on Reinforcement Learning (动手学强化学习). First, policy iteration:

import copy


class CliffWalkingEnv:
    """Cliff Walking environment."""

    def __init__(self, ncol=12, nrow=4):
        self.ncol = ncol  # number of columns in the grid world
        self.nrow = nrow  # number of rows in the grid world
        # Transition model P[state][action] = [(p, next_state, reward, done)]
        self.P = self.createP()

    def createP(self):
        # Initialize the transition model
        P = [[[] for j in range(4)] for i in range(self.nrow * self.ncol)]
        # 4 actions: change[0] = up, change[1] = down, change[2] = left, change[3] = right.
        # The origin (0, 0) of the coordinate system is the top-left corner.
        change = [[0, -1], [0, 1], [-1, 0], [1, 0]]
        for i in range(self.nrow):
            for j in range(self.ncol):
                for a in range(4):
                    # On the cliff or in the goal state no further interaction is possible,
                    # so every action yields reward 0 and stays in place.
                    if i == self.nrow - 1 and j > 0:
                        P[i * self.ncol + j][a] = [(1, i * self.ncol + j, 0, True)]
                        continue
                    # All other positions
                    next_x = min(self.ncol - 1, max(0, j + change[a][0]))
                    next_y = min(self.nrow - 1, max(0, i + change[a][1]))
                    next_state = next_y * self.ncol + next_x
                    reward = -1
                    done = False
                    # The next position is on the cliff or at the goal
                    if next_y == self.nrow - 1 and next_x > 0:
                        done = True
                        if next_x != self.ncol - 1:  # the next position is on the cliff
                            reward = -100
                    P[i * self.ncol + j][a] = [(1, next_state, reward, done)]
        return P


class PolicyIteration():
    """Policy iteration algorithm including policy evaluation and improvement.

    Hyperparameters:
        1. env: environment
        2. gamma: discount factor
        3. theta: convergence threshold

    Structure:
        1. Policy evaluation
        2. Policy improvement
        3. Policy iteration
    """

    def __init__(self, env, gamma, theta):
        self.env = env
        self.gamma = gamma
        self.theta = theta
        self.state_dim = env.nrow * env.ncol
        self.action_dim = 4
        self.pi = [[0.25, 0.25, 0.25, 0.25] for i in range(self.state_dim)]
        self.state_value = [0] * self.state_dim

    def policy_evaluation(self):
        counter = 0
        while True:
            difference = 0
            new_state_values = [0] * self.state_dim
            for state in range(self.state_dim):
                qsa = []
                for action in range(self.action_dim):
                    reset_states_value = 0
                    for situation in self.env.P[state][action]:
                        transition_P, next_state, reward, done = situation
                        reset_states_value += transition_P * (self.gamma * self.state_value[next_state] * (1 - done))
                    # Each (state, action) pair has exactly one deterministic transition in this
                    # environment, so adding the last reward outside the loop still gives the correct Q value.
                    state_action_value = reward + reset_states_value
                    qsa.append(self.pi[state][action] * state_action_value)
                    # new_state_values[state] += self.pi[state][action] * state_action_value
                new_state_values[state] = sum(qsa)
                difference = max(difference, abs(new_state_values[state] - self.state_value[state]))
            self.state_value = new_state_values
            counter += 1
            if counter > 1e4 or difference < self.theta:
                print("Policy evaluation finished: %d rounds" % counter)
                break

    def policy_improvement(self):
        for state in range(self.state_dim):
            state_action_value_list = []
            for action in range(self.action_dim):
                state_value = 0
                for situation in self.env.P[state][action]:
                    transition_P, next_state, reward, done = situation
                    state_value += self.gamma * transition_P * self.state_value[next_state] * (1 - done)
                state_action_value = reward + state_value
                state_action_value_list.append(state_action_value)
            max_value = max(state_action_value_list)
            max_number = state_action_value_list.count(max_value)
            # Spread probability evenly over all greedy actions
            self.pi[state] = [1 / max_number if value == max_value else 0 for value in state_action_value_list]
        print("Policy improvement finished.")
        print()
        return self.pi

    def policy_iteration(self):
        while True:
            self.policy_evaluation()
            old_pi = copy.deepcopy(self.pi)
            new_pi = self.policy_improvement()
            if old_pi == new_pi:
                break


def print_agent(agent, action_meaning, disaster=[], end=[]):
    print("State values:")
    for i in range(agent.env.nrow):
        for j in range(agent.env.ncol):
            # Keep each value to 6 characters so the grid lines up
            print('%6.6s' % ('%.3f' % agent.state_value[i * agent.env.ncol + j]), end=' ')
        print()
    print("Policy:")
    for i in range(agent.env.nrow):
        for j in range(agent.env.ncol):
            # Special states, e.g. the cliff in Cliff Walking
            if (i * agent.env.ncol + j) in disaster:
                print('****', end=' ')
            elif (i * agent.env.ncol + j) in end:  # goal state
                print('EEEE', end=' ')
            else:
                a = agent.pi[i * agent.env.ncol + j]
                pi_str = ''
                for k in range(len(action_meaning)):
                    pi_str += action_meaning[k] if a[k] > 0 else 'o'
                print(pi_str, end=' ')
        print()


if __name__ == "__main__":
    env = CliffWalkingEnv()
    action_meaning = ['^', 'v', '<', '>']
    gamma = 0.9
    theta = 1e-3
    agent = PolicyIteration(env=env, gamma=gamma, theta=theta)
    agent.policy_iteration()
    print_agent(agent, action_meaning, list(range(37, 47)), [47])
Value iteration:

import copy
import numpy as np


class CliffWalkingEnv:
    """Cliff Walking environment."""

    def __init__(self, ncol=12, nrow=4):
        self.ncol = ncol  # number of columns in the grid world
        self.nrow = nrow  # number of rows in the grid world
        # Transition model P[state][action] = [(p, next_state, reward, done)]
        self.P = self.createP()

    def createP(self):
        # Initialize the transition model
        P = [[[] for j in range(4)] for i in range(self.nrow * self.ncol)]
        # 4 actions: change[0] = up, change[1] = down, change[2] = left, change[3] = right.
        # The origin (0, 0) of the coordinate system is the top-left corner.
        change = [[0, -1], [0, 1], [-1, 0], [1, 0]]
        for i in range(self.nrow):
            for j in range(self.ncol):
                for a in range(4):
                    # On the cliff or in the goal state no further interaction is possible,
                    # so every action yields reward 0 and stays in place.
                    if i == self.nrow - 1 and j > 0:
                        P[i * self.ncol + j][a] = [(1, i * self.ncol + j, 0, True)]
                        continue
                    # All other positions
                    next_x = min(self.ncol - 1, max(0, j + change[a][0]))
                    next_y = min(self.nrow - 1, max(0, i + change[a][1]))
                    next_state = next_y * self.ncol + next_x
                    reward = -1
                    done = False
                    # The next position is on the cliff or at the goal
                    if next_y == self.nrow - 1 and next_x > 0:
                        done = True
                        if next_x != self.ncol - 1:  # the next position is on the cliff
                            reward = -100
                    P[i * self.ncol + j][a] = [(1, next_state, reward, done)]
        return P


class ValueIteration():
    def __init__(self, env, gamma, theta):
        self.env = env
        self.gamma = gamma
        self.theta = theta
        self.state_numbers = env.ncol * env.nrow
        self.action_numbers = 4
        self.state_values = [0] * self.state_numbers
        self.pi = [[0.25, 0.25, 0.25, 0.25] for i in range(self.state_numbers)]

    def value_evaluation(self):
        counter = 0
        while True:
            new_state_values = copy.deepcopy(self.state_values)
            for state in range(self.state_numbers):
                action_values = []
                for action in range(self.action_numbers):
                    next_state_value = 0
                    for situation in self.env.P[state][action]:
                        transition_P, next_state, reward, done = situation
                        next_state_value += transition_P * (reward + self.gamma * self.state_values[next_state] * (1 - done))
                    # action_value = reward + next_state_value
                    action_values.append(next_state_value)
                new_state_values[state] = max(action_values)
            difference = max(abs(np.array(self.state_values) - np.array(new_state_values)))
            self.state_values = new_state_values
            # print(difference)
            counter += 1
            if counter >= 1e3 or difference < self.theta:
                print("Value evaluation finished: %d rounds." % counter)
                break

    def value_improvement(self):
        for state in range(self.state_numbers):
            action_values = []
            for action in range(self.action_numbers):
                next_states_value = 0
                for situation in self.env.P[state][action]:
                    transition_P, next_state, reward, done = situation
                    next_states_value += transition_P * (reward + self.gamma * self.state_values[next_state] * (1 - done))
                action_values.append(next_states_value)
            max_action_value = max(action_values)
            max_action_value_counter = action_values.count(max_action_value)
            # Spread probability evenly over all greedy actions
            self.pi[state] = [1 / max_action_value_counter if value == max_action_value else 0 for value in action_values]

    def value_iteration(self):
        while True:
            self.value_evaluation()
            old_pi = copy.deepcopy(self.pi)
            self.value_improvement()
            # Value iteration needs only one evaluation-to-convergence pass
            # followed by a single greedy policy extraction.
            break
            # if old_pi == self.pi:
            #     break


def print_agent(agent, action_meaning, disaster=[], end=[]):
    print("State values:")
    for i in range(agent.env.nrow):
        for j in range(agent.env.ncol):
            # Keep each value to 6 characters so the grid lines up
            print('%6.6s' % ('%.3f' % agent.state_values[i * agent.env.ncol + j]), end=' ')
        print()
    print("Policy:")
    for i in range(agent.env.nrow):
        for j in range(agent.env.ncol):
            # Special states, e.g. the cliff in Cliff Walking
            if (i * agent.env.ncol + j) in disaster:
                print('****', end=' ')
            elif (i * agent.env.ncol + j) in end:  # goal state
                print('EEEE', end=' ')
            else:
                a = agent.pi[i * agent.env.ncol + j]
                pi_str = ''
                for k in range(len(action_meaning)):
                    pi_str += action_meaning[k] if a[k] > 0 else 'o'
                print(pi_str, end=' ')
        print()


if __name__ == "__main__":
    env = CliffWalkingEnv()
    gamma = 0.9
    theta = 0.001
    action_meaning = ['^', 'v', '<', '>']
    agent = ValueIteration(env=env, gamma=gamma, theta=theta)
    agent.value_iteration()
    print_agent(agent, action_meaning, list(range(37, 47)), [47])
