
Solving CartPole with Deep Q-Networks (DQN)

Overview of the CartPole Environment

CartPole is a classic benchmark environment in reinforcement learning. It was originally introduced in OpenAI's Gym library and is still widely used in Gymnasium, Gym's successor.

The core task is as follows:
A vertical pole is attached by a joint to a cart that moves left and right along a one-dimensional track. The agent's goal is to keep the pole from falling over by choosing, at every step, whether to push the cart to the left or to the right.

Environment Specification

  • Observation space: at every time step the environment returns a real-valued vector of length 4 containing:

    1. cart position
    2. cart velocity
    3. pole angle relative to the vertical
    4. pole angular velocity
  • Action space: two discrete actions:

    • 0: push the cart to the left
    • 1: push the cart to the right
  • Reward: the agent receives a reward of +1 for every step on which the pole has not yet fallen.

  • Termination (done): the episode ends when the pole tilts beyond a fixed angle from the vertical or the cart leaves the track boundaries (a minimal interaction loop illustrating these fields follows below).
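The specification above maps directly onto the Gymnasium API. The snippet below is a minimal, self-contained interaction loop with a random policy; it only illustrates the observation, action, reward and termination fields and is not part of the project code.

import gymnasium as gym

env = gym.make("CartPole-v1")
obs, info = env.reset(seed=0)  # obs: [cart position, cart velocity, pole angle, pole angular velocity]

total_reward = 0.0
done = False
while not done:
    action = env.action_space.sample()  # random action: 0 = push left, 1 = push right
    obs, reward, terminated, truncated, info = env.step(action)
    total_reward += reward              # +1 for every surviving step
    done = terminated or truncated

print(f"Episode return with a random policy: {total_reward}")
env.close()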

Code Structure

📦 cart_pole (project root)
├── 📂 agent_dqn
│   ├── 📂 algorithm
│   │   ├── 📄 __init__.py
│   │   └── 📄 algorithm.py
│   ├── 📂 conf
│   │   ├── 📄 __init__.py
│   │   └── 📄 conf.py
│   ├── 📂 feature
│   │   ├── 📄 __init__.py
│   │   ├── 📄 monitor.py
│   │   └── 📄 processor.py
│   ├── 📂 model
│   │   ├── 📄 __init__.py
│   │   └── 📄 model.py
│   ├── 📂 workflow
│   │   ├── 📄 __init__.py
│   │   └── 📄 train_workflow.py
│   ├── 📄 __init__.py
│   └── 📄 agent.py
└── 📄 train_test.py

algorithm

import math
import random
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from cart_pole.agent_dqn.conf.conf import Config
from cart_pole.agent_dqn.model.model import DQN
from cart_pole.agent_dqn.feature.monitor import Monitor
from cart_pole.agent_dqn.feature.processor import Processor


class Algorithm:
    def __init__(self, device, monitor: Monitor):
        self.device = device
        self.monitor = monitor
        self.capacity = Config.MEMORY_SIZE
        self.memory = []
        self.push_count = 0
        self.epsilon = Config.EPSILON_MAX
        self.epsilon_max = Config.EPSILON_MAX
        self.epsilon_min = Config.EPSILON_MIN
        self.epsilon_decay = Config.EPSILON_DECAY
        # Initialize the policy (online) network
        self.model = DQN(Config.DIM_OF_OBSERVATION, Config.DIM_OF_ACTION).to(device)
        # Initialize the target network
        self.target_model = DQN(Config.DIM_OF_OBSERVATION, Config.DIM_OF_ACTION).to(device)
        # Synchronize the target network with the policy network
        self.target_model.load_state_dict(self.model.state_dict())
        self.target_model.eval()
        # Optimizer
        self.optimizer = optim.Adam(params=self.model.parameters(), lr=Config.LR)
        self.predict_count = 0
        self.train_count = 0

    def memory_push(self, experience) -> None:
        """Add an experience to the replay memory.

        If the memory has not yet reached its capacity, the experience is
        appended; otherwise the oldest experience is overwritten
        (ring-buffer behaviour).
        """
        if len(self.memory) < self.capacity:
            self.memory.append(experience)
        else:
            self.memory[self.push_count % self.capacity] = experience
        self.push_count += 1

    def sample(self, batch_size: int):
        """Draw a random mini-batch of `batch_size` experiences."""
        return random.sample(self.memory, batch_size)

    def can_provide_sample(self, batch_size: int) -> bool:
        """Whether the replay memory holds enough experiences to sample from."""
        return len(self.memory) >= batch_size

    def learn(self, list_sample_data):
        # Convert the sampled experiences to tensors
        states, actions, next_states, dones, rewards = Processor.extract_tensors(
            list_sample_data, self.device)

        # Compute the TD target with the target network; terminal transitions
        # are masked out via `dones` so they bootstrap to zero
        self.target_model.eval()
        with torch.no_grad():
            next_q_values = self.target_model(next_states).max(dim=1)[0].unsqueeze(-1)
            target_q_values = rewards + Config.GAMMA * next_q_values * (1 - dones)

        # Q-values of the taken actions from the policy network
        current_q_values = self.model(states).gather(dim=1, index=actions)

        # Loss
        loss = F.mse_loss(current_q_values, target_q_values)

        # Gradient step
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.train_count += 1

        # Periodically update the target network
        if self.train_count % Config.TARGET_UPDATE_INTERVAL == 0:
            self.update_target_q()

        # Report the loss to the monitor
        if self.train_count % Config.LOG_UPDATE_INTERVAL == 0:
            self.monitor.add_loss_info(loss.detach().item())

    def predict(self, obs, exploit_flag=False):
        # Exponentially decayed exploration rate
        self.epsilon = self.epsilon_min + (self.epsilon_max - self.epsilon_min) * \
            math.exp(-1. * self.predict_count * self.epsilon_decay)
        # Update the running step count
        self.predict_count += 1
        # Epsilon-greedy action selection
        if not exploit_flag and np.random.rand() < self.epsilon:
            return np.random.randint(Config.DIM_OF_ACTION)
        with torch.no_grad():
            obs = torch.FloatTensor(obs).unsqueeze(0).to(self.device)
            q_values = self.model(obs)
            return q_values.argmax().item()

    def update_target_q(self):
        self.target_model.load_state_dict(self.model.state_dict())
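For intuition, the target computed in learn() is the one-step TD target r + γ · max_a Q_target(s', a), with the bootstrap term masked to zero for terminal transitions. The standalone toy computation below uses made-up numbers (not project data) purely to show the tensor shapes involved.

import torch

GAMMA = 0.999
# A made-up batch of 3 transitions
rewards = torch.tensor([[1.0], [1.0], [1.0]])                 # shape [B, 1]
dones = torch.tensor([[0.0], [0.0], [1.0]])                   # last transition is terminal
next_q = torch.tensor([[0.8, 1.2], [0.5, 0.4], [0.9, 0.7]])   # Q_target(s', a), shape [B, 2]

values = next_q.max(dim=1)[0].unsqueeze(-1)                   # max_a Q_target(s', a), shape [B, 1]
target_q = rewards + GAMMA * values * (1 - dones)             # terminal row keeps only its reward
print(target_q)  # approximately [[2.1988], [1.4995], [1.0000]]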

conf

from collections import namedtuple


class Config:
    DIM_OF_OBSERVATION = 4
    DIM_OF_ACTION = 2
    EPSILON_MAX = 1
    EPSILON_MIN = 0.01
    EPSILON_DECAY = 0.001
    GAMMA = 0.999
    LR = 0.001
    SEED = 234
    MEMORY_SIZE = 100000
    NUM_EPISODES = 1000
    TARGET_UPDATE_INTERVAL = 10
    LOG_UPDATE_INTERVAL = 1
    BATCH_SIZE = 256
    Experience = namedtuple('Experience',
                            ('state', 'action', 'next_state', 'done', 'reward'))
    ENV_RENDER_MODE = 'rgb_array'
    NUM_FOURIER_BASE = 1
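As a quick sanity check on EPSILON_MAX, EPSILON_MIN and EPSILON_DECAY, the exponential schedule used in Algorithm.predict can be evaluated on its own; the step counts below are arbitrary.

import math

EPSILON_MAX, EPSILON_MIN, EPSILON_DECAY = 1, 0.01, 0.001

def epsilon_at(step):
    # Same formula as in Algorithm.predict
    return EPSILON_MIN + (EPSILON_MAX - EPSILON_MIN) * math.exp(-1. * step * EPSILON_DECAY)

for step in (0, 1000, 5000, 20000):
    print(step, round(epsilon_at(step), 3))
# prints roughly: 0 1.0, 1000 0.374, 5000 0.017, 20000 0.01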

processor

import torch
import numpy as np
from typing import NamedTuple
from cart_pole.agent_dqn.conf.conf import Config


class Processor:
    @staticmethod
    def extract_tensors(experiences: NamedTuple, device):
        """Accept a batch of Experiences, transpose it into an Experience
        of batches, and convert each field to a tensor on `device`."""
        # Convert batch of Experiences to Experience of batches
        batch = Config.Experience(*zip(*experiences))

        t_states = torch.tensor(np.array(batch.state), dtype=torch.float32).to(device)
        t_actions = torch.tensor(batch.action).unsqueeze(-1).to(device)
        t_next_state = torch.tensor(np.array(batch.next_state), dtype=torch.float32).to(device)
        t_rewards = torch.tensor(batch.reward, dtype=torch.float32).unsqueeze(-1).to(device)
        t_dones = torch.tensor(batch.done).float().unsqueeze(-1).to(device)

        return t_states, t_actions, t_next_state, t_dones, t_rewards
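The only non-obvious step here is Config.Experience(*zip(*experiences)), which transposes a list of per-step Experience tuples into a single Experience whose fields are tuples of batched values. A tiny standalone illustration with dummy values (not project data):

from collections import namedtuple

Experience = namedtuple('Experience', ('state', 'action', 'next_state', 'done', 'reward'))

batch = [
    Experience([0.10, 0.0, 0.02, 0.0], 1, [0.11, 0.2, 0.01, -0.1], False, 1.0),
    Experience([0.11, 0.2, 0.01, -0.1], 0, [0.12, 0.0, 0.03, 0.2], True, 1.0),
]

transposed = Experience(*zip(*batch))
print(transposed.action)  # (1, 0)
print(transposed.done)    # (False, True)
print(transposed.reward)  # (1.0, 1.0)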

monitor

import matplotlib.pyplot as plt
import seaborn as sns
import warnings

warnings.filterwarnings('ignore')
sns.set_style("whitegrid")
plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly


class Monitor:
    def __init__(self):
        self.loss_log = []
        self.epsilon_log = []
        self.reward_log = []
        self.episode_duration_log = []

    def add_loss_info(self, loss):
        """Append a new loss value to the monitor."""
        self.loss_log.append(loss)

    def add_epsilon_info(self, epsilon):
        """Append a new epsilon value to the monitor."""
        self.epsilon_log.append(epsilon)

    def add_reward_info(self, reward):
        """Append the total reward of a finished episode to the monitor."""
        self.reward_log.append(reward)

    def add_duration_info(self, duration):
        """Append the duration (number of steps) of a finished episode to the monitor."""
        self.episode_duration_log.append(duration)

    def plot_loss(self):
        """Plot the loss curve."""
        plt.figure()
        plt.plot(self.loss_log)
        plt.xlabel('iteration')
        plt.ylabel('loss')
        plt.title('TD error / loss')
        plt.show()

    def plot_epsilon(self):
        """Plot the epsilon curve."""
        plt.figure()
        plt.plot(self.epsilon_log)
        plt.xlabel('episode')
        plt.ylabel('epsilon')
        plt.title('Epsilon Variation with Episode')
        plt.show()

    def plot_reward(self):
        """Plot the reward curve."""
        plt.figure()
        plt.plot(self.reward_log)
        plt.xlabel('episode')
        plt.ylabel('reward')
        plt.title('Reward Variation with Episode')
        plt.show()

    def plot_all_log(self):
        """Plot the loss, episode-duration and epsilon curves on one figure."""
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 6))

        # Loss curve
        ax1.plot(self.loss_log)
        ax1.set_xlabel('Iteration')
        ax1.set_ylabel('TD error / loss')
        ax1.set_title('Loss Variation with Iteration')
        ax1.grid(True)

        # Episode-duration curve
        ax2.plot(self.episode_duration_log)
        ax2.set_xlabel('episode')
        ax2.set_ylabel('step')
        ax2.set_title('Step Variation with Episode')
        ax2.grid(True)

        # Epsilon curve
        ax3.plot(self.epsilon_log)
        ax3.set_xlabel('episode')
        ax3.set_ylabel('epsilon')
        ax3.set_title('Epsilon Variation with Episode')
        ax3.grid(True)

        plt.tight_layout()
        plt.show()

model

import torch.nn as nn
import torch.nn.functional as F


class DQN(nn.Module):
    def __init__(self, num_state_features, num_actions):
        super().__init__()
        # A small fully connected network: 4 -> 32 -> 64 -> 128 -> 2
        self.fc1 = nn.Linear(in_features=num_state_features, out_features=32)
        self.fc2 = nn.Linear(in_features=32, out_features=64)
        self.fc3 = nn.Linear(in_features=64, out_features=128)
        # One output per possible action (push left / push right)
        self.out = nn.Linear(in_features=128, out_features=num_actions)

    def forward(self, t):
        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))
        t = F.relu(self.fc3(t))
        # No activation on the output layer: raw Q-values
        t = self.out(t)
        return t
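A quick shape check for this network (a standalone sketch; it assumes the cart_pole package from the layout above is importable): a batch of observations of shape [batch, 4] should yield one Q-value per action, i.e. shape [batch, 2].

import torch
from cart_pole.agent_dqn.model.model import DQN

net = DQN(num_state_features=4, num_actions=2)
dummy_obs = torch.randn(8, 4)    # a batch of 8 fake CartPole observations
q_values = net(dummy_obs)
print(q_values.shape)            # torch.Size([8, 2])
print(q_values.argmax(dim=1))    # greedy action index per observation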

train_workflow

import time

from tqdm import tqdm
from itertools import count

from cart_pole.agent_dqn.agent import Agent
from cart_pole.agent_dqn.conf.conf import Config


def run_episodes(num_episodes, env, agent: Agent, exploit_flag=False):
    for episode in tqdm(range(num_episodes)):
        # Reset the task and get the initial state
        state = env.reset(seed=Config.SEED)[0]
        for duration in count():
            action = agent.predict(state, exploit_flag=exploit_flag)
            next_state, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            agent.algorithm.memory_push(
                Config.Experience(state, action, next_state, done, reward))
            state = next_state
            if done:
                agent.monitor.add_duration_info(duration)
                # Log the current epsilon so the monitor's epsilon panel is populated
                agent.monitor.add_epsilon_info(agent.algorithm.epsilon)
                break
            if agent.algorithm.can_provide_sample(Config.BATCH_SIZE):
                sample_data = agent.algorithm.sample(Config.BATCH_SIZE)
                agent.learn(sample_data)

agent

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
from cart_pole.agent_dqn.algorithm.algorithm import Algorithm
from cart_pole.agent_dqn.feature.monitor import Monitor
from cart_pole.agent_dqn.feature.processor import Processor

warnings.filterwarnings('ignore')
sns.set_style("whitegrid")
plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly


class Agent:
    def __init__(self, device, monitor: Monitor):
        self.device = device
        self.monitor = monitor
        self.algorithm = Algorithm(device, monitor)

    def predict(self, obs, exploit_flag=False):
        return self.algorithm.predict(obs, exploit_flag=exploit_flag)

    def learn(self, list_sample_data):
        self.algorithm.learn(list_sample_data)

    def save_model(self, path=None, id="1"):
        pass

    def load_model(self, path=None, id="1"):
        pass

train_test

import torch
from cart_pole.agent_dqn.agent import Agent
from cart_pole.agent_dqn.feature.monitor import Monitor
from cart_pole.agent_dqn.workflow.train_workflow import *
import gymnasium as gym


if __name__ == "__main__":
    env = gym.make('CartPole-v1', render_mode="rgb_array").unwrapped
    monitor = Monitor()
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device = torch.device("cpu")
    agent = Agent(device, monitor)

    run_episodes(2000, env, agent)
    monitor.plot_all_log()
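After training, a short greedy rollout can be appended to the __main__ block above to see how long the learned policy survives. This is a hypothetical continuation: exploration is disabled via exploit_flag=True, and the rollout is capped at 500 steps because the unwrapped environment has no time limit.

    # Greedy evaluation after training: no exploration, only the learned policy
    eval_env = gym.make('CartPole-v1').unwrapped
    state = eval_env.reset()[0]
    steps = 0
    for _ in range(500):  # cap the rollout, since the unwrapped env never truncates
        action = agent.predict(state, exploit_flag=True)
        state, reward, terminated, truncated, _ = eval_env.step(action)
        steps += 1
        if terminated or truncated:
            break
    print(f"Greedy policy survived {steps} steps")
    eval_env.close()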