强化学习DQN解决Cart_Pole问题
CartPole 环境简介
CartPole 是强化学习领域的一个经典测试环境,源自 Barto、Sutton 和 Anderson 于 1983 年提出的倒立摆(cart-pole)控制问题,后来被 OpenAI 的 Gym 库收录,如今在 Gymnasium(Gym 的继任者)中仍然被广泛使用。
该环境的核心任务是:
一根竖直的杆子通过一个铰接点连接在小车上,小车可以在一维轨道上左右移动。智能体(agent)的目标是通过控制小车向左或向右的动作,保持杆子不倒下。
环境设定
- 状态空间(observation):环境在每个时刻都会返回一个长度为 4 的实数向量,包含:
  - 小车位置
  - 小车速度
  - 杆子与竖直方向的夹角
  - 杆子角速度
- 动作空间(action):离散的两个动作:
  - 0:小车向左移动
  - 1:小车向右移动
- 奖励函数(reward):每一步只要杆子没有倒下,智能体就会得到 +1 的奖励。
- 终止条件(done):当杆子与竖直方向偏离超过一定角度,或者小车位置超出轨道边界时,游戏结束。
代码结构
📦 根目录
├── 📂 agent_dqn
│   ├── 📂 algorithm
│   │   ├── 📄 __init__.py
│   │   └── 📄 algorithm.py
│   ├── 📂 conf
│   │   ├── 📄 __init__.py
│   │   └── 📄 conf.py
│   ├── 📂 feature
│   │   ├── 📄 __init__.py
│   │   ├── 📄 monitor.py
│   │   └── 📄 processor.py
│   ├── 📂 model
│   │   ├── 📄 __init__.py
│   │   └── 📄 model.py
│   ├── 📂 workflow
│   │   ├── 📄 __init__.py
│   │   └── 📄 train_workflow.py
│   ├── 📄 __init__.py
│   └── 📄 agent.py
└── 📄 train_test.py
algorithm
import math
import random
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from cart_pole.agent_dqn.conf.conf import Config
from cart_pole.agent_dqn.model.model import DQN
from cart_pole.agent_dqn.feature.monitor import Monitor
from cart_pole.agent_dqn.feature.processor import Processor


class Algorithm:
    """DQN learner: replay memory, epsilon-greedy policy and target network.

    The class owns two ``DQN`` networks (the trained policy network and a
    periodically synchronised target network), an Adam optimizer, a
    fixed-capacity replay memory and the exploration schedule.
    """

    def __init__(self, device, monitor: Monitor):
        """Build the networks, optimizer, replay memory and counters.

        Args:
            device: torch device on which both networks and all batches live.
            monitor: collector that receives the training loss.
        """
        self.device = device
        self.monitor = monitor

        # Replay memory: a list that is appended to until it reaches
        # ``capacity`` and is then overwritten in ring-buffer fashion.
        self.capacity = Config.MEMORY_SIZE
        self.memory = []
        self.push_count = 0

        # Exploration rate; recomputed on every ``predict`` call.
        self.epsilon = Config.EPSILON_MAX
        self.epsilon_max = Config.EPSILON_MAX
        self.epsilon_min = Config.EPSILON_MIN
        self.epsilon_decay = Config.EPSILON_DECAY

        # Policy network (trained) and target network (frozen copy).
        self.model = DQN(Config.DIM_OF_OBSERVATION, Config.DIM_OF_ACTION).to(device)
        self.target_model = DQN(Config.DIM_OF_OBSERVATION, Config.DIM_OF_ACTION).to(device)
        self.target_model.load_state_dict(self.model.state_dict())
        self.target_model.eval()

        self.optimizer = optim.Adam(params=self.model.parameters(), lr=Config.LR)

        self.predict_count = 0
        self.train_count = 0

    def memory_push(self, experience) -> None:
        """Store one experience in the replay memory.

        While the memory holds fewer than ``capacity`` items the experience
        is appended; afterwards it overwrites the slot at
        ``push_count % capacity``.

        Args:
            experience: the transition to store.
        """
        if len(self.memory) < self.capacity:
            self.memory.append(experience)
        else:
            self.memory[self.push_count % self.capacity] = experience
        self.push_count += 1

    def sample(self, batch_size: int):
        """Return ``batch_size`` experiences drawn without replacement.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored
                experiences (raised by ``random.sample``).
        """
        return random.sample(self.memory, batch_size)

    def can_provide_sample(self, batch_size: int) -> bool:
        """Return True once the memory holds at least ``batch_size`` items."""
        return len(self.memory) >= batch_size

    def learn(self, list_sample_data):
        """Run one gradient step on a batch of sampled experiences.

        The TD target is ``r + GAMMA * max_a Q_target(s', a) * (1 - done)``.
        Every ``Config.TARGET_UPDATE_INTERVAL`` steps the target network is
        re-synchronised, and every ``Config.LOG_UPDATE_INTERVAL`` steps the
        loss is sent to the monitor.

        Args:
            list_sample_data: batch of experiences as returned by ``sample``.
        """
        states, actions, next_states, dones, rewards = Processor.extract_tensors(
            list_sample_data, self.device
        )

        # Q(s, a) of the actions that were actually taken, shape (batch, 1).
        current_q_values = self.model(states.float()).gather(dim=1, index=actions.long())

        # Bootstrap target from the frozen network. ``keepdim=True`` and the
        # explicit reshapes keep every term at (batch, 1); mixing (batch,)
        # with (batch, 1) would silently broadcast to (batch, batch).
        with torch.no_grad():
            next_q_values = self.target_model(next_states.float()).max(dim=1, keepdim=True)[0]
            # Terminal transitions must not bootstrap, so mask them out
            # using the stored done flags.
            not_done = 1.0 - dones.reshape(-1, 1).to(next_q_values.dtype)
            step_rewards = rewards.reshape(-1, 1).to(next_q_values.dtype)
            target_q_values = step_rewards + Config.GAMMA * next_q_values * not_done

        loss = F.mse_loss(current_q_values, target_q_values)

        # Clear stale gradients, back-propagate, then apply the update.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.train_count += 1

        if self.train_count % Config.TARGET_UPDATE_INTERVAL == 0:
            self.update_target_q()

        if self.train_count % Config.LOG_UPDATE_INTERVAL == 0:
            self.monitor.add_loss_info(loss.detach().item())

    def predict(self, obs, exploit_flag=False):
        """Choose an action for ``obs`` with an epsilon-greedy policy.

        Epsilon decays exponentially from ``epsilon_max`` towards
        ``epsilon_min`` with the number of ``predict`` calls made so far.

        Args:
            obs: a single observation accepted by ``torch.as_tensor``.
            exploit_flag: when True, always take the greedy action.

        Returns:
            The index of the selected action.
        """
        decay = math.exp(-1.0 * self.predict_count * self.epsilon_decay)
        self.epsilon = self.epsilon_min + (self.epsilon_max - self.epsilon_min) * decay
        self.predict_count += 1

        if not exploit_flag and np.random.rand() < self.epsilon:
            return np.random.randint(Config.DIM_OF_ACTION)

        with torch.no_grad():
            obs_batch = torch.as_tensor(
                obs, dtype=torch.float32, device=self.device
            ).unsqueeze(0)
            q_values = self.model(obs_batch)
            return q_values.argmax().item()

    def update_target_q(self):
        """Copy the policy network weights into the target network."""
        self.target_model.load_state_dict(self.model.state_dict())
conf
from collections import namedtuple


class Config:
    """Static hyper-parameters and shared types for the CartPole DQN agent."""

    # --- Environment ---------------------------------------------------
    DIM_OF_OBSERVATION = 4  # length of the observation vector
    DIM_OF_ACTION = 2  # number of discrete actions
    SEED = 234  # seed handed to env.reset
    ENV_RENDER_MODE = 'rgb_array'
    NUM_EPISODES = 1000

    # --- Exploration: epsilon moves from MAX towards MIN ----------------
    EPSILON_MAX = 1
    EPSILON_MIN = 0.01
    EPSILON_DECAY = 0.001  # rate inside exp(-step * EPSILON_DECAY)

    # --- Optimisation ---------------------------------------------------
    GAMMA = 0.999  # discount factor of the TD target
    LR = 0.001  # Adam learning rate
    BATCH_SIZE = 256
    MEMORY_SIZE = 100000  # replay memory capacity

    # --- Bookkeeping intervals, counted in learn() calls ----------------
    TARGET_UPDATE_INTERVAL = 10
    LOG_UPDATE_INTERVAL = 1

    # NOTE(review): not referenced by any code visible here -- confirm use.
    NUM_FOURIER_BASE = 1

    # One replay-memory record; the field order is relied upon by callers
    # that build it positionally.
    Experience = namedtuple(
        'Experience',
        ['state', 'action', 'next_state', 'done', 'reward'],
    )
processor
import torch
import numpy as np
from typing import NamedTuple
from cart_pole.agent_dqn.conf.conf import Config


class Processor:
    """Converts sampled replay-memory experiences into batched tensors."""

    @staticmethod
    def extract_tensors(experiences: list, device):
        """Transpose a batch of Experiences into an Experience of batches.

        Args:
            experiences: non-empty collection of ``Config.Experience``
                records.
            device: torch device the resulting tensors are placed on.

        Returns:
            Tuple ``(states, actions, next_states, dones, rewards)``.
            ``states`` and ``next_states`` are float32 with one row per
            experience; ``actions`` is int64 and ``dones`` / ``rewards``
            are float32, each with a trailing dimension of size 1.
        """
        # zip(*...) groups the i-th field of every record together.
        batch = Config.Experience(*zip(*experiences))

        def to_tensor(values, dtype):
            # Stacking into a single ndarray first avoids the slow
            # element-by-element path torch takes for a tuple of ndarrays,
            # and the explicit dtype keeps every batch consistent no
            # matter which scalar types the environment returned.
            return torch.as_tensor(np.asarray(values), dtype=dtype, device=device)

        t_states = to_tensor(batch.state, torch.float32)
        t_actions = to_tensor(batch.action, torch.int64).unsqueeze(-1)
        t_next_state = to_tensor(batch.next_state, torch.float32)
        t_rewards = to_tensor(batch.reward, torch.float32).unsqueeze(-1)
        t_dones = to_tensor(batch.done, torch.float32).unsqueeze(-1)
        return t_states, t_actions, t_next_state, t_dones, t_rewards
monitor
import matplotlib.pyplot as plt
import seaborn as sns
import warnings

warnings.filterwarnings('ignore')
sns.set_style("whitegrid")
plt.rcParams['axes.unicode_minus'] = False  # work around minus-sign rendering problems


class Monitor:
    """Collects training statistics in lists and plots them with matplotlib."""

    def __init__(self):
        self.loss_log = []
        self.epsilon_log = []
        self.reward_log = []
        self.episode_duration_log = []

    def add_loss_info(self, loss):
        """Append one loss value."""
        self.loss_log.append(loss)

    def add_epsilon_info(self, epsilon):
        """Append one epsilon value."""
        self.epsilon_log.append(epsilon)

    def add_reward_info(self, reward):
        """Append the reward of one episode."""
        self.reward_log.append(reward)

    def add_duration_info(self, duration):
        """Append the duration of one episode."""
        self.episode_duration_log.append(duration)

    def _show_curve(self, series, x_label, y_label, title):
        """Draw ``series`` on a new figure with the given labels and show it."""
        plt.figure()
        plt.plot(series)
        plt.xlabel(x_label)
        plt.ylabel(y_label)
        plt.title(title)
        plt.show()

    def plot_loss(self):
        """Plot the recorded loss values."""
        self._show_curve(self.loss_log, '迭代次数', 'loss', 'TD error/ loss')

    def plot_epsilon(self):
        """Plot the recorded epsilon values."""
        self._show_curve(
            self.epsilon_log, 'episode', 'epsilon', 'Epsilon Variation with Episode'
        )

    def plot_reward(self):
        """Plot the recorded episode rewards."""
        self._show_curve(
            self.reward_log, 'episode', 'reward', 'Reward Variation with Episode'
        )

    def plot_all_log(self):
        """Plot loss, episode duration and epsilon side by side in one figure."""
        fig, axes = plt.subplots(1, 3, figsize=(15, 6))
        panels = (
            (self.loss_log, 'Iteration', 'TD error/ loss',
             'Loss Variation with Iteration'),
            (self.episode_duration_log, 'episode', 'step',
             'Step Variation with Episode'),
            (self.epsilon_log, 'episode', 'epsilon',
             'Epsilon Variation with Episode'),
        )
        for axis, (series, x_label, y_label, title) in zip(axes, panels):
            axis.plot(series)
            axis.set_xlabel(x_label)
            axis.set_ylabel(y_label)
            axis.set_title(title)
            axis.grid(True)
        plt.tight_layout()
        plt.show()
model
import torch.nn as nn
import torch.nn.functional as F


class DQN(nn.Module):
    """Fully connected Q-network: state features in, one value per action out."""

    def __init__(self, num_state_features, num_actions):
        """Create three hidden layers (32, 64, 128 units) and the output layer.

        Args:
            num_state_features: size of the input feature vector.
            num_actions: number of outputs, one per action.
        """
        super().__init__()
        self.fc1 = nn.Linear(num_state_features, 32)
        self.fc2 = nn.Linear(32, 64)
        self.fc3 = nn.Linear(64, 128)
        self.out = nn.Linear(128, num_actions)

    def forward(self, t):
        """Apply ReLU after each hidden layer; the output layer stays linear."""
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            t = F.relu(hidden_layer(t))
        return self.out(t)
train_workflow
import time
from tqdm import tqdm
from cart_pole.agent_dqn.agent import Agent
from cart_pole.agent_dqn.conf.conf import Config
from itertools import count


def run_episodes(num_episodes, env, agent: Agent, exploit_flag=False):
    """Run ``num_episodes`` episodes, storing transitions and training the agent.

    After every non-final step one batch is sampled and learned from, as
    soon as the replay memory holds at least ``Config.BATCH_SIZE`` items.
    At the end of each episode the step count, the summed reward and the
    current epsilon are sent to ``agent.monitor``.

    Args:
        num_episodes: number of episodes to run.
        env: environment whose ``reset`` returns ``(observation, info)`` and
            whose ``step`` returns
            ``(observation, reward, terminated, truncated, info)``.
        agent: agent exposing ``predict``, ``learn``, ``algorithm`` and
            ``monitor``.
        exploit_flag: forwarded to ``agent.predict``; when True the agent
            always acts greedily.
    """
    for episode in tqdm(range(num_episodes)):
        # Seed only the first reset: re-seeding with the same value on
        # every episode would restart each one from the same initial state.
        seed = Config.SEED if episode == 0 else None
        state = env.reset(seed=seed)[0]
        episode_reward = 0.0

        # ``steps`` is 1-based so it equals the number of env.step calls.
        for steps in count(1):
            action = agent.predict(state, exploit_flag=exploit_flag)
            next_state, reward, terminated, truncated, info = env.step(action)
            episode_reward += reward

            # Store only true termination as the done flag; a time-limit
            # truncation is not a terminal state for bootstrapping.
            agent.algorithm.memory_push(
                Config.Experience(state, action, next_state, terminated, reward)
            )
            state = next_state

            if terminated or truncated:
                agent.monitor.add_duration_info(steps)
                agent.monitor.add_reward_info(episode_reward)
                agent.monitor.add_epsilon_info(agent.algorithm.epsilon)
                break

            if agent.algorithm.can_provide_sample(Config.BATCH_SIZE):
                sample_data = agent.algorithm.sample(Config.BATCH_SIZE)
                agent.learn(sample_data)
agent
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
from cart_pole.agent_dqn.algorithm.algorithm import Algorithm
from cart_pole.agent_dqn.feature.monitor import Monitor
from cart_pole.agent_dqn.feature.processor import Processor

warnings.filterwarnings('ignore')
sns.set_style("whitegrid")
plt.rcParams['axes.unicode_minus'] = False  # work around minus-sign rendering problems


class Agent:
    """Facade that owns an ``Algorithm`` and forwards predict/learn calls to it."""

    def __init__(self, device, monitor: Monitor):
        """Keep ``device`` and ``monitor`` and build the underlying algorithm."""
        self.device, self.monitor = device, monitor
        self.algorithm = Algorithm(device, monitor)

    def predict(self, obs, exploit_flag=False):
        """Return the algorithm's action for ``obs``."""
        chosen_action = self.algorithm.predict(obs, exploit_flag=exploit_flag)
        return chosen_action

    def learn(self, list_sample_data):
        """Forward one batch of sampled experiences to the algorithm."""
        self.algorithm.learn(list_sample_data)

    def save_model(self, path=None, id="1"):
        """Placeholder: currently does nothing."""

    def load_model(self, path=None, id="1"):
        """Placeholder: currently does nothing."""
train_test
import torch
from cart_pole.agent_dqn.agent import Agent
from cart_pole.agent_dqn.feature.monitor import Monitor
from cart_pole.agent_dqn.workflow.train_workflow import *
import gymnasium as gym

if __name__ == "__main__":
    # Script entry point: train on CartPole-v1 and plot the collected logs.
    env = gym.make('CartPole-v1', render_mode="rgb_array").unwrapped
    try:
        monitor = Monitor()
        # To train on GPU when available, use instead:
        #   torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        device = torch.device("cpu")
        agent = Agent(device, monitor)
        run_episodes(2000, env, agent)
    finally:
        # Always release the environment, even if training raised.
        env.close()
    monitor.plot_all_log()