基于深度强化学习的Atari中的SpaceInvaders
基于深度强化学习的Atari中的SpaceInvaders
1. 故事背景
space_invaders_v2 是基于 Atari 经典街机游戏《Space Invaders》的多智能体版本。它将原始单人射击游戏扩展为双人模式,嵌入合作与竞争机制,适用于多智能体强化学习研究。
1.1 原始设定
玩家控制一艘太空飞船,驻守在地球轨道上方。外星人组成的入侵舰队从屏幕顶部逐行推进,意图摧毁地球防线。玩家必须通过发射激光击落外星人,同时利用护盾抵挡敌方炸弹。随着敌人逐渐逼近,游戏节奏加快,挑战性增强。
1.2 多智能体扩展设定
在 space_invaders_v2 中,两个玩家(智能体)分别控制各自的飞船。飞船共享生命池,但奖励机制中引入竞争元素:
- 击落敌人可得分。
- 如果对方被击中,你将获得额外奖励(200 分)。
- 玩家可以选择合作清关,也可以通过策略性“牺牲”对方来获取更高分数。
这种设定将原本的单人防御任务转化为一个混合博弈场景,既有团队协作的可能,也有策略性竞争的空间,非常适合研究多智能体之间的行为演化与策略学习。
2. 任务目标
- 控制飞船击落外星人,避免被击中。
- 与另一智能体(玩家)共享生命池,但奖励机制中存在竞争因素。
- 游戏目标是最大化自身得分,同时在有限生命内尽可能清除敌人。
3. 动作空间(Action Space)
类型:离散动作空间,共 6 个动作。
- 0: No-op(无动作)
- 1: Fire(发射子弹)
- 2: Move up(向上移动)
- 3: Move right(向右移动)
- 4: Move left(向左移动)
- 5: Move down(向下移动)
这些动作是从 Atari 控制器的最小必要动作集中提取的,确保训练效率与策略表达能力。
4. 状态空间(Observation Space)
- 类型:图像帧(RGB)
- 维度:(210, 160, 3),即 210 高 × 160 宽 × 3 通道
- 像素值范围:[0, 255]
- 每个智能体在其轮次中接收到当前游戏画面作为观察输入。
5. 奖励机制(Reward Mechanism)
- 击落外星人:每个敌人得分 5–30 分,具体取决于其位置。
- 击落飞行的 UFO:奖励 100 分。
- 对手被击中:你获得额外奖励 200 分,引入竞争激励。
- 游戏结束:当任意一方被击中 3 次,生命耗尽,游戏终止。
- 奖励是稀疏的,但通过击杀敌人和利用竞争机制可以获得密集奖励。
6. 环境参数配置(Environment Parameters)
以下是可选参数,用于调整游戏难度与策略空间:
参数名 | 类型 | 描述 |
---|---|---|
alternating_control | bool | 是否启用交替射击机制,控制权在两智能体间轮换。 |
moving_shields | bool | 护盾是否左右移动,增加防御难度。 |
zigzaging_bombs | bool | 敌人投弹是否呈 Z 字形摆动,提高规避难度。 |
fast_bomb | bool | 敌人投弹速度是否加快。 |
invisible_invaders | bool | 敌人是否不可见,极大增加挑战性。 |
render_mode | str | 可选 “human” 或 “rgb_array”,用于渲染画面或图像输出。 |
obs_type | str | “rgb” 或 “grayscale”,控制观察图像的颜色模式。 |
full_action_space | bool | 是否启用完整动作空间(默认关闭,使用最小动作集)。 |
frameskip | int | 每个动作执行的帧数(默认 4),影响游戏速度与策略响应。 |
repeat_action_probability | float | 控制动作粘性(sticky actions),用于引入随机性与防止过拟合。 |
7. 代码实现
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from pettingzoo.atari import space_invaders_v2
from collections import deque
import random# 超参数
GAMMA = 0.99
LR = 1e-4
EPS_START = 1.0
EPS_END = 0.05
EPS_DECAY = 1e-5
MEMORY_SIZE = 10000
BATCH_SIZE = 32device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 简单的CNN Q网络
class QNet(nn.Module):def __init__(self, action_dim):super().__init__()self.conv = nn.Sequential(nn.Conv2d(4, 32, 8, 4), nn.ReLU(),nn.Conv2d(32, 64, 4,