当前位置: 首页 > news >正文

从代码学习深度强化学习 - 多臂老虎机 PyTorch版

文章目录

  • 前言
  • 创建多臂老虎机环境
  • 多臂老虎机算法基本框架(基类)
  • 1. ε-贪心算法 (Epsilon-Greedy)
  • 2. 随时间衰减的ε-贪婪算法 (Decaying ε-Greedy)
  • 3. 上置信界算法 (Upper Confidence Bound, UCB)
  • 4. 汤普森采样算法 (Thompson Sampling)
  • 总结


前言

欢迎来到“从代码学习深度强化学习”系列!在本篇文章中,我们将深入探讨一个强化学习中的经典问题——多臂老虎机(Multi-Armed Bandit, MAB)

多臂老虎机问题,顾名思义,源于一个赌徒在赌场面对一排老虎机(即“多臂老虎机”)的场景。每个老虎机(“臂”)都有其内在的、未知的获奖概率。赌徒的目标是在有限的回合内,通过选择拉动不同的老虎机,来最大化自己的总收益。

这看似简单的场景,却完美地诠释了强化学习中的一个核心困境:探索(Exploration)与利用(Exploitation)的权衡

  • 利用(Exploitation):选择当前已知收益最高的老虎机。这能保证我们在短期内获得不错的收益,但可能会错过一个潜在收益更高但尚未被充分尝试的选项。
  • 探索(Exploration):尝试那些我们不确定其收益的老虎机。这可能会在短期内牺牲一些收益,但却有机会发现全局最优的选择,从而获得更高的长期总回报。

为了量化算法的性能,我们引入一个重要概念——累积懊悔(Cumulative Regret)。懊悔指的是在某一步选择的动作所带来的期望收益与“上帝视角”下最优动作的期望收益之差。一个优秀的算法,其目标就是最小化在整个过程中的累积懊悔。

在本篇博客中,我们将通过 Python 代码,从零开始实现一个多臂老虎机环境,并逐步实现和对比以下四种经典的求解策略:

  1. ε-贪心算法 (Epsilon-Greedy)
  2. 随时间衰减的ε-贪心算法 (Decaying Epsilon-Greedy)
  3. 上置信界算法 (Upper Confidence Bound, UCB)
  4. 汤普森采样算法 (Thompson Sampling)

关于 PyTorch: 尽管标题提及 PyTorch,但对于 MAB 这种基础问题,使用 NumPy 能更清晰地展示算法的核心逻辑,而无需引入深度学习框架的复杂性。本文中的实现将基于 NumPy,但其核心思想(如价值估计、策略选择)是构建更复杂的深度强化学习算法(如DQN)的基石,在那些场景中 PyTorch 将发挥关键作用。

让我们开始吧!

完整代码:下载链接

创建多臂老虎机环境

首先,我们需要一个模拟环境。我们创建一个 BernoulliBandit 类来模拟一个拥有 K 个臂的老虎机。每个臂都服从伯努利分布,即每次拉动它,会以一个固定的概率 p 获得奖励 1(获奖),以 1-p 的概率获得奖励 0(未获奖)。在我们的环境中,这 K 个臂的获奖概率 p 是在初始化时随机生成的,并且对我们的算法(智能体)是未知的。

# 导入需要使用的库
import numpy as np  # numpy是支持数组和矩阵运算的科学计算库
import matplotlib.pyplot as plt  # matplotlib是绘图库class BernoulliBandit:"""伯努利多臂老虎机类该类实现了一个多臂老虎机问题的环境,每个拉杆都服从伯努利分布"""def __init__(self, K):"""初始化伯努利多臂老虎机参数:K (int): 拉杆个数,标量属性:probs (numpy.ndarray): 每个拉杆的获奖概率数组,维度为 (K,)best_idx (int): 获奖概率最大的拉杆索引,标量best_prob (float): 最大的获奖概率值,标量K (int): 拉杆总数,标量"""# 随机生成K个0~1之间的数,作为拉动每根拉杆的获奖概率# probs: (K,) - K个拉杆的获奖概率数组self.probs = np.random.uniform(size=K)# 找到获奖概率最大的拉杆索引# best_idx: 标量 - 最优拉杆的索引号self.best_idx = np.argmax(self.probs)# 获取最大的获奖概率# best_prob: 标量 - 最大获奖概率值self.best_prob = self.probs[self.best_idx]# 保存拉杆总数# K: 标量 - 拉杆个数self.K = Kdef step(self, k):"""执行一次拉杆动作当玩家选择了k号拉杆后,根据该拉杆的获奖概率返回奖励结果参数:k (int): 选择的拉杆编号,标量,取值范围为 [0, K-1]返回:int: 奖励结果,标量1 表示获奖0 表示未获奖"""# 根据k号拉杆的获奖概率进行伯努利采样# np.random.rand(): 标量 - 生成[0,1)之间的随机数# self.probs[k]: 标量 - k号拉杆的获奖概率if np.random.rand() < self.probs[k]:return 1  # 获奖else:return 0  # 未获奖# 设定随机种子,使实验具有可重复性
np.random.seed(1)# 设置拉杆数量
# K: 标量 - 多臂老虎机的拉杆个数
K = 10# 创建一个10臂伯努利老虎机实例
# bandit_10_arm: BernoulliBandit对象 - 包含10个拉杆的老虎机
bandit_10_arm = BernoulliBandit(K)# 输出老虎机的基本信息
print("随机生成了一个%d臂伯努利老虎机" % K)
print("获奖概率最大的拉杆为%d号,其获奖概率为%.4f" % (bandit_10_arm.best_idx, bandit_10_arm.best_prob))

运行以上代码,我们创建了一个10臂老虎机,并打印出了最优拉杆的信息。在我们的实验中,1号拉杆是收益最高的,其获奖概率为 0.7203。这个信息算法本身是不知道的,但我们可以用它来计算懊悔。

随机生成了一个10臂伯努利老虎机
获奖概率最大的拉杆为1号,其获奖概率为0.7203

多臂老虎机算法基本框架(基类)

为了方便实现和比较不同的算法,我们先定义一个 Solver 基类。这个基类包含了所有算法都需要共享的功能,例如记录每个臂被拉动的次数、记录历史动作以及计算和更新累积懊悔。具体的决策逻辑(run_one_step)将由各个子类来实现。

# 导入需要使用的库
import numpy as np  # numpy是支持数组和矩阵运算的科学计算库class Solver:"""多臂老虎机算法基础框架类该类为多臂老虎机问题的算法提供基本框架,包含通用的状态记录和懊悔计算功能具体的动作选择策略需要在子类中实现"""def __init__(self, bandit):"""初始化多臂老虎机算法求解器参数:bandit (BernoulliBandit): 多臂老虎机环境对象属性:bandit (BernoulliBandit): 多臂老虎机环境实例counts (numpy.ndarray): 每根拉杆的尝试次数数组,维度为 (K,)regret (float): 当前步的累积懊悔值,标量actions (list): 记录每一步动作选择的拉杆编号列表,维度为 (num_steps,)regrets (list): 记录每一步累积懊悔值的列表,维度为 (num_steps,)"""# 初始化多臂老虎机环境# bandit: BernoulliBandit对象 - 多臂老虎机环境实例self.bandit = bandit# 初始化每根拉杆的尝试次数,全部设为0# counts: (K,) - 记录每根拉杆被选择的次数self.counts = np.zeros(self.bandit.K)# 初始化累积懊悔值# regret: 标量 - 当前的累积懊悔值self.regret = 0.0# 维护一个列表,记录每一步的动作选择# actions: list,长度为num_steps - 存储每次选择的拉杆编号self.actions = []# 维护一个列表,记录每一步的累积懊悔# regrets: list,长度为num_steps - 存储每次的累积懊悔值self.regrets = []def update_regret(self, k):"""计算并更新累积懊悔值该方法采用上帝视角计算懊悔值,即已知最优拉杆的真实概率懊悔值 = 最优拉杆期望收益 - 当前选择拉杆期望收益参数:k (int): 本次动作选择的拉杆编号,标量,取值范围为 [0, K-1]"""</
http://www.xdnf.cn/news/932635.html

相关文章:

  • 【深度学习|学习笔记】自监督学习(Self-Supervised Learning, SSL)在遥感领域中的典型应用案例及其在小样本学习中的作用,附代码。
  • LeetCode --- 452周赛
  • 高保真组件库:按钮
  • GitHub 趋势日报 (2025年06月07日)
  • Langgraph实战-自省式RAG: Self-RAG
  • 材料力学速通
  • 北京工作周期7,8,9,10
  • 【react实战】如何实现监听窗口大小变化
  • 2025HNCTF - Crypto
  • webstorm 配置Eslint
  • Springboot 基于MessageSource配置国际化
  • C#调用Rust动态链接库DLL的案例
  • ​RBAC(基于角色的访问控制)权限管理详解
  • 学习日记-day24-6.8
  • 鸿蒙API自翻译
  • 70常用控件_QVBoxLayout的使用
  • 指针的使用——字符、字符串、字符串数组(char*)
  • C++进阶--C++11--智能指针(重点)
  • 12.7Swing控件6 JList
  • gitcode与github加速计划
  • LabVIEW Modbus 主站冗余控制
  • css | class中 ‘.‘ 和 ‘:‘ 的使用 | 如,何时用 .is-selected{ ... } 何时用 :hover{...}?
  • 3Ds Max 2026安装包+教程网盘下载与安装教程指南
  • [特殊字符] Whisper 模型介绍(OpenAI 语音识别系统)
  • WEB3全栈开发——面试专业技能点P1Node.js / Web3.js / Ethers.js
  • 【RockeMQ】第2节|RocketMQ快速实战以及核⼼概念详解(二)
  • 图神经网络(GNN)模型的基本原理
  • MySQL:CTE 通用表达式
  • 在React 中安装和配置 shadcn/ui
  • 我用Cursor写了一个视频转文字工具,已开源,欢迎体验