当前位置: 首页 > news >正文

从代码学习深度学习 - 近似训练 PyTorch版

文章目录

  • 前言
  • 负采样 (Negative Sampling)
  • 层序Softmax (Hierarchical Softmax)
    • 代码示例
  • 总结


前言

在自然语言处理(NLP)领域,词嵌入(Word Embeddings)技术如Word2Vec(包括Skip-gram和CBOW模型)已经成为一项基础且强大的工具。它们能够将词语映射到低维稠密向量空间,使得语义相近的词在向量空间中的距离也相近。然而,这些模型在训练过程中,尤其是在计算输出层softmax时,会面临一个巨大的挑战:词汇表通常非常庞大(几十万甚至数百万个词)。对整个词典进行求和并计算梯度,其计算成本是巨大的。

为了解决这个问题,研究者们提出了多种近似训练方法,旨在降低计算复杂度,同时保持模型性能。本篇将重点介绍两种在Word2Vec中广泛应用的近似训练方法:负采样(Negative Sampling)分层Softmax(Hierarchical Softmax)。我们将以跳元模型(Skip-gram)为例来阐述这两种方法的核心思想。

虽然本文标题带有"PyTorch版",但所提供的笔记主要集中在理论层面。在实际的PyTorch应用中,这些近似训练方法通常会通过专门的损失函数或者自定义神经网络层来实现。

完整代码:下载链接

负采样 (Negative Sampling)

负采样通过修改原始目标函数来降低计算复杂度。其核心思想是,对于每个训练样本(中心词和其上下文中的一个真实目标词),我们不再尝试预测整个词汇表中哪个词是正确的上下文词,而是将其转化为一个二分类问题:区分真实的目标词和一些随机采样的“噪声”词(负样本)。

给定中心词 w c w_c wc 的上下文窗口,任意上下文词 w o w_o wo 来自该上下文窗口的事件被认为是由下式建模概率的事件:

P ( D = 1 ∣ w c , w o ) = σ ( u o ⊤ v c ) P(D=1 \mid w_c, w_o) = \sigma(\mathbf{u}_o^\top \mathbf{v}_c) P(D=1wc,wo)=σ(uovc)

其中 σ \sigma σ 使用了sigmoid激活函数的定义:

σ ( x ) = 1 1 + exp ⁡ ( − x ) \sigma(x) = \frac{1}{1 + \exp(-x)} σ(x)=1+exp(x)1

u o \mathbf{u}_o uo 是上下文词 w o w_o wo 的输出向量(或称为上下文向量), v c \mathbf{v}_c vc 是中心词 w c w_c wc 的输入向量(或称为词向量)。

原始的Word2Vec模型旨在最大化文本序列中所有这些正样本事件的联合概率。具体而言,给定长度为 T T T 的文本序列,以 w ( t ) w^{(t)} w(t) 表示时间步 t t t 的词,并使上下文窗口为 m m m,考虑最大化联合概率:

∏ t = 1 T ∏ − m ≤ j ≤ m , j ≠ 0 P ( D = 1 ∣ w ( t ) , w ( t + j ) ) \prod_{t=1}^T \prod_{-m \leq j \leq m, j \neq 0} P(D=1 \mid w^{(t)}, w^{(t+j)}) t=1Tmjm,j=0P(D=1w(t),w(t+j))

然而,这个目标函数只考虑了正样本。如果仅最大化这个概率,模型可能会学到将所有词向量都变得非常大,导致 σ ( u o ⊤ v c ) \sigma(\mathbf{u}_o^\top \mathbf{v}_c) σ(uovc) 接近1,但这并没有实际意义。

为了使目标函数更有意义,负采样引入了负样本。

S S S 表示上下文词 w o w_o wo 来自中心词 w c w_c wc 的上下文窗口的事件。对于这个涉及 w o w_o wo 的事件,我们从一个预定义的分布 P ( w ) P(w) P(w)(通常是词频的3/4次方)中采样 K K K 个不是来自这个上下文窗口的“噪声词”(负样本)。用 N k N_k Nk 表示噪声词 w k ( k = 1 , … , K ) w_k (k=1, \ldots, K) wk(k=1,,K) 不是来自 w c w_c wc 的上下文窗口的事件(即它们是负样本, D = 0 D=0 D=0)。

假设正例和负例 S , N 1 , … , N K S, N_1, \ldots, N_K S,N1,,NK 的这些事件是相互独立的。负采样将上述联合概率(仅涉及正例)修改为,对于每个中心词-上下文词对 ( w ( t ) , w ( t + j ) ) (w^{(t)}, w^{(t+j)}) (w(t),w(t+j)),最大化以下概率࿱

http://www.xdnf.cn/news/516223.html

相关文章:

  • 什么是着色器 Shader
  • fme条件属性值
  • 【LLIE专题】基于Retinex理论的transformer暗光增强
  • Spark,数据提取和保存
  • LearnOpenGL---着色器
  • 板凳-------Mysql cookbook学习 (三)
  • Qwen3数据集格式化指南:从对话模板到推理模式,结合Unsloth实战演练
  • 高压BOOST芯片-TPQ80302
  • <前端小白> 前端网页知识点总结
  • 脚本一键完成alist直接在windows上进行磁盘映射为本地磁盘webdav
  • jqGrid冻结列错行问题,将冻结表格(悬浮表格)与 正常表格进行高度同步
  • 计算机网络概要
  • Oracle 内存优化
  • 给easyui的textbox绑定回车事件
  • 翻译:20250518
  • Go 后端中双 token 的实现模板
  • 需求与实际业务需求脱节,怎么办?
  • 安卓端互动娱乐房卡系统调试实录:从UI到协议的万字深拆(第一章)
  • QT学习3
  • Socket.IO是什么?适用哪些场景?
  • 基于马尔可夫链的状态转换,用概率模型预测股市走势
  • 2025年- H31-Lc139- 242.回文链表(快慢指针)---java版--需2刷
  • 新型太空电梯——半摆卫星太空电梯 的设计与验证
  • 【Python数据处理系列】输入txt,读取特定字符转换成特定csv数据并输出
  • PointNet++:点云处理的升级版算法
  • WebSocket实时双向通信:从基础到实战
  • 3:OpenCV—视频播放
  • 彻底解决docker代理配置与无法拉取镜像问题
  • 第二章 苍穹外卖
  • Git基础原理和使用