从代码学习深度学习 - 近似训练 PyTorch版
文章目录
- 前言
- 负采样 (Negative Sampling)
- 层序Softmax (Hierarchical Softmax)
- 代码示例
- 总结
前言
在自然语言处理(NLP)领域,词嵌入(Word Embeddings)技术如Word2Vec(包括Skip-gram和CBOW模型)已经成为一项基础且强大的工具。它们能够将词语映射到低维稠密向量空间,使得语义相近的词在向量空间中的距离也相近。然而,这些模型在训练过程中,尤其是在计算输出层softmax时,会面临一个巨大的挑战:词汇表通常非常庞大(几十万甚至数百万个词)。对整个词典进行求和并计算梯度,其计算成本是巨大的。
为了解决这个问题,研究者们提出了多种近似训练方法,旨在降低计算复杂度,同时保持模型性能。本篇将重点介绍两种在Word2Vec中广泛应用的近似训练方法:负采样(Negative Sampling)和分层Softmax(Hierarchical Softmax)。我们将以跳元模型(Skip-gram)为例来阐述这两种方法的核心思想。
虽然本文标题带有"PyTorch版",但所提供的笔记主要集中在理论层面。在实际的PyTorch应用中,这些近似训练方法通常会通过专门的损失函数或者自定义神经网络层来实现。
完整代码:下载链接
负采样 (Negative Sampling)
负采样通过修改原始目标函数来降低计算复杂度。其核心思想是,对于每个训练样本(中心词和其上下文中的一个真实目标词),我们不再尝试预测整个词汇表中哪个词是正确的上下文词,而是将其转化为一个二分类问题:区分真实的目标词和一些随机采样的“噪声”词(负样本)。
给定中心词 w c w_c wc 的上下文窗口,任意上下文词 w o w_o wo 来自该上下文窗口的事件被认为是由下式建模概率的事件:
P ( D = 1 ∣ w c , w o ) = σ ( u o ⊤ v c ) P(D=1 \mid w_c, w_o) = \sigma(\mathbf{u}_o^\top \mathbf{v}_c) P(D=1∣wc,wo)=σ(uo⊤vc)
其中 σ \sigma σ 使用了sigmoid激活函数的定义:
σ ( x ) = 1 1 + exp ( − x ) \sigma(x) = \frac{1}{1 + \exp(-x)} σ(x)=1+exp(−x)1
u o \mathbf{u}_o uo 是上下文词 w o w_o wo 的输出向量(或称为上下文向量), v c \mathbf{v}_c vc 是中心词 w c w_c wc 的输入向量(或称为词向量)。
原始的Word2Vec模型旨在最大化文本序列中所有这些正样本事件的联合概率。具体而言,给定长度为 T T T 的文本序列,以 w ( t ) w^{(t)} w(t) 表示时间步 t t t 的词,并使上下文窗口为 m m m,考虑最大化联合概率:
∏ t = 1 T ∏ − m ≤ j ≤ m , j ≠ 0 P ( D = 1 ∣ w ( t ) , w ( t + j ) ) \prod_{t=1}^T \prod_{-m \leq j \leq m, j \neq 0} P(D=1 \mid w^{(t)}, w^{(t+j)}) ∏t=1T∏−m≤j≤m,j=0P(D=1∣w(t),w(t+j))
然而,这个目标函数只考虑了正样本。如果仅最大化这个概率,模型可能会学到将所有词向量都变得非常大,导致 σ ( u o ⊤ v c ) \sigma(\mathbf{u}_o^\top \mathbf{v}_c) σ(uo⊤vc) 接近1,但这并没有实际意义。
为了使目标函数更有意义,负采样引入了负样本。
用 S S S 表示上下文词 w o w_o wo 来自中心词 w c w_c wc 的上下文窗口的事件。对于这个涉及 w o w_o wo 的事件,我们从一个预定义的分布 P ( w ) P(w) P(w)(通常是词频的3/4次方)中采样 K K K 个不是来自这个上下文窗口的“噪声词”(负样本)。用 N k N_k Nk 表示噪声词 w k ( k = 1 , … , K ) w_k (k=1, \ldots, K) wk(k=1,…,K) 不是来自 w c w_c wc 的上下文窗口的事件(即它们是负样本, D = 0 D=0 D=0)。
假设正例和负例 S , N 1 , … , N K S, N_1, \ldots, N_K S,N1,…,NK 的这些事件是相互独立的。负采样将上述联合概率(仅涉及正例)修改为,对于每个中心词-上下文词对 ( w ( t ) , w ( t + j ) ) (w^{(t)}, w^{(t+j)}) (w(t),w(t+j)),最大化以下概率