从代码学习深度学习 - 自然语言推断:使用注意力 PyTorch版
文章目录
- 前言
- 模型详解
- 第一步:注意(Attending)
- MLP 辅助函数
- Attend 模块
- 第二步:比较(Comparing)
- Compare 模块
- 第三步:聚合(Aggregating)
- Aggregate 模块
- 整合模型:DecomposableAttention
- 模型训练与评估
- 数据准备
- 数据处理工具集
- 加载数据
- 模型创建与初始化
- 词向量与设备工具
- 实例化与初始化
- 训练过程
- 训练工具集
- 执行训练
- 使用模型进行预测
- 总结
前言
自然语言推断(Natural Language Inference, NLI)是自然语言处理(NLP)领域一个核心且富有挑战性的任务。它的目标是判断两句话之间的逻辑关系,通常分为蕴涵(Entailment)、矛盾(Contradiction)和中性(Neutral)三类。例如,给定前提“一个人在骑马”,我们希望模型能推断出假设“一个人在户外”是蕴涵关系,而“一个人在睡觉”是矛盾关系。
传统的NLI模型常常依赖于复杂的深度网络结构,如循环神经网络(RNN)或卷积神经网络(CNN)。然而,在2016年,Parikh等人提出了一种新颖且高效的“可分解注意力模型”(Decomposable Attention Model),该模型完全摒弃了循环和卷积层,仅通过注意力机制和简单的多层感知机(MLP),就在当时权威的SNLI数据集上取得了顶尖的性能,并且参数量更少。
这篇博客将带领大家深入探索这一经典模型的PyTorch实现。我们将从零开始,逐一剖析模型的三个核心步骤——注意(Attending)、比较(Comparing)和聚合(Aggregating),并通过详尽的代码和注释,帮助你彻底理解其工作原理。无论你是NLP初学者还是希望温故知新的开发者,相信本文都能为你带来启发。
完整代码:下载链接
模型详解
可分解注意力模型的整体思想非常直观:它不依赖于对句子时序信息的复杂编码,而是将一个句子中的每个词元与另一个句子中的所有词元进行“对齐”,然后比较这些对齐的信息,最后将所有比较结果汇总起来,做出最终的逻辑判断。
如上图所示,整个模型由三个联合训练的阶段构成:注意、比较和聚合。接下来,我们将逐一深入代码实现。
第一步:注意(Attending)
“注意”阶段的核心任务是建立前提(Premise)和假设(Hypothesis)中词元之间的“软对齐”(Soft Alignment)。例如,对于前提“我确实需要睡眠”和假设“我累了”,我们希望模型能够自动地将假设中的“我”与前提中的“我”对齐,并将“累”与“睡眠”对齐。这种对齐是通过注意力权重实现的。
我们用 A = ( a 1 , … , a m ) \mathbf{A} = (\mathbf{a}_1, \ldots, \mathbf{a}_m) A=(a1,…,am) 和 B = ( b 1 , … , b n ) \mathbf{B} = (\mathbf{b}_1, \ldots, \mathbf{b}_n) B=(b1,…,bn) 分别表示前提和假设的词向量序列。首先,我们将每个词向量通过一个共享的MLP网络 f f f 进行变换。然后,前提中第 i i i 个词元和假设中第 j j j 个词元的注意力分数 e i j e_{ij} eij 计算如下:
e i j = f ( a i ) ⊤ f ( b j ) e_{ij} = f(\mathbf{a}_i)^\top f(\mathbf{b}_j) eij=f(ai)⊤f(bj)
这里有一个巧妙的“分解”技巧:函数 f f f 分别作用于 a i \mathbf{a}_i ai 和 b j \mathbf{b}_j bj ,而不是将它们配对作为输入。这使得计算复杂度从 O ( m n ) O(mn) O(mn) 降低到了 O ( m + n ) O(m+n) O(m+n),大大提升了效率。
在计算出所有词元对之间的注意力分数后,我们使用Softmax进行归一化,从而得到一个序列对另一个序列的加权平均表示。具体来说:
- β i \beta_i βi:对于前提中的每个词元 a i \mathbf{a}_i ai,我们计算它与假设中所有词元对齐后的表示,即假设序列的加权平均。
- α j \alpha_j αj:对于假设中的每个词元 b j \mathbf{b}_j bj,我们计算它与前提中所有词元对齐后的表示,即前提序列的加权平均。
下面是实现这个过程的代码。首先,我们定义一个通用的 mlp
函数,它将作为我们模型中的基本构建块。
MLP 辅助函数
import torch
import torch.nn as nndef mlp(num_inputs, num_hiddens, flatten):"""构建多层感知机(MLP)网络参数:num_inputs (int): 输入特征维度num_hiddens (int): 隐藏层神经元数量flatten (bool): 是否在激活函数后进行展平操作返回:nn.Sequential: 构建好的MLP网络模型"""# 创建网络层列表,用于存储各个网络层net = []# 添加第一个Dropout层,防止过拟合net.append(nn.Dropout(0.2))# 添加第一个全连接层,将输入特征映射到隐藏层net.append(nn.Linear(num_inputs, num_hiddens))# 添加ReLU激活函数,引入非线性net.append(nn.ReLU())# 根据flatten参数决定是否添加展平层if flatten:# 将多维张量展平为一维,从第1维开始展平(保留batch维度)net.append(nn.Flatten(start_dim=1))# 添加第二个Dropout层,继续防止过拟合net.append(nn.Dropout(0.2))# 添加第二个全连接层,隐藏层到隐藏层的映射net.append(nn.Linear(num_hiddens, num_hiddens))# 添加第二个ReLU激活函数net.append(nn.ReLU())# 再次根据flatten参数决定是否添加展平层if flatten:# 将多维张量展平为一维net.append(nn.Flatten(start_dim=1))# 将所有网络层组合成Sequential模型并返回return nn.Sequential(*net)
Attend 模块
有了mlp
函数,我们就可以构建Attend
模块了。
import torch.nn.functional as Fclass Attend(nn.Module):"""注意力机制类,用于计算两个序列之间的软对齐实现论文中提到的注意力机制:e_ij = f(a_i)^T f(b_j)其中f是MLP网络,用于计算注意力权重"""def __init__(self, num_inputs, num_hiddens, **kwargs):"""初始化注意力机制参数:num_inputs (int): 输入特征维度(embed_size)num_hiddens (int): MLP隐藏层维度**kwargs: 传递给父类的其他参数"""super(Attend, self).__init__(**kwargs)# 创建MLP网络f,用于将输入序列映射到注意力空间# 输入维度: (batch_size, seq_len, num_inputs)# 输出维度: (batch_size, seq_len, num_hiddens)self.f = mlp(num_inputs, num_hiddens, flatten=False)def forward(self, A, B):"""前向传播,计算两个序列之间的软对齐参数:A (torch.Tensor): 序列A,形状为(batch_size, seq_A_len, embed_size)B (torch.Tensor): 序列B,形状为(batch_size, seq_B_len, embed_size)返回:beta (torch.Tensor): 序列B对序列A的软对齐,形状为(batch_size, seq_A_len, embed_size)alpha (torch.Tensor): 序列A对序列B的软对齐,形状为(batch_size, seq_B_len, embed_size)"""# 通过MLP网络f处理输入序列A和B# f_A的形状:(batch_size, seq_A_len, num_hiddens)# f_B的形状:(batch_size, seq_B_len, num_hiddens)f_A = self.f(A)f_B = self.f(B)# 计算注意力得分矩阵e# e的形状:(batch_size, seq_A_len, seq_B_len)e = torch.bmm(f_A, f_B.permute(0, 2, 1))# 计算beta:序列B被软对齐到序列A的每个词元# 对最后一维(seq_B_len)进行softmax,得到序列A中每个词元对序列B中所有词元的注意力权重# beta的形状:(batch_size, seq_A_len, embed_size)beta = torch.bmm(F.softmax(e, dim=-1), B)# 计算alpha:序列A被软对齐到序列B的每个词元# 对e进行转置后,对最后一维(seq_A_len)进行softmax# alpha的形状:(batch_size, seq_B_l