当前位置: 首页 > news >正文

从代码学习深度学习 - 自然语言推断:使用注意力 PyTorch版

文章目录

  • 前言
  • 模型详解
    • 第一步:注意(Attending)
      • MLP 辅助函数
      • Attend 模块
    • 第二步:比较(Comparing)
      • Compare 模块
    • 第三步:聚合(Aggregating)
      • Aggregate 模块
    • 整合模型:DecomposableAttention
  • 模型训练与评估
    • 数据准备
      • 数据处理工具集
      • 加载数据
    • 模型创建与初始化
      • 词向量与设备工具
      • 实例化与初始化
    • 训练过程
      • 训练工具集
      • 执行训练
    • 使用模型进行预测
  • 总结


前言

自然语言推断(Natural Language Inference, NLI)是自然语言处理(NLP)领域一个核心且富有挑战性的任务。它的目标是判断两句话之间的逻辑关系,通常分为蕴涵(Entailment)、矛盾(Contradiction)和中性(Neutral)三类。例如,给定前提“一个人在骑马”,我们希望模型能推断出假设“一个人在户外”是蕴涵关系,而“一个人在睡觉”是矛盾关系。

传统的NLI模型常常依赖于复杂的深度网络结构,如循环神经网络(RNN)或卷积神经网络(CNN)。然而,在2016年,Parikh等人提出了一种新颖且高效的“可分解注意力模型”(Decomposable Attention Model),该模型完全摒弃了循环和卷积层,仅通过注意力机制和简单的多层感知机(MLP),就在当时权威的SNLI数据集上取得了顶尖的性能,并且参数量更少。

这篇博客将带领大家深入探索这一经典模型的PyTorch实现。我们将从零开始,逐一剖析模型的三个核心步骤——注意(Attending)比较(Comparing)聚合(Aggregating),并通过详尽的代码和注释,帮助你彻底理解其工作原理。无论你是NLP初学者还是希望温故知新的开发者,相信本文都能为你带来启发。

在这里插入图片描述

完整代码:下载链接

模型详解

可分解注意力模型的整体思想非常直观:它不依赖于对句子时序信息的复杂编码,而是将一个句子中的每个词元与另一个句子中的所有词元进行“对齐”,然后比较这些对齐的信息,最后将所有比较结果汇总起来,做出最终的逻辑判断。

在这里插入图片描述

如上图所示,整个模型由三个联合训练的阶段构成:注意、比较和聚合。接下来,我们将逐一深入代码实现。

第一步:注意(Attending)

“注意”阶段的核心任务是建立前提(Premise)和假设(Hypothesis)中词元之间的“软对齐”(Soft Alignment)。例如,对于前提“我确实需要睡眠”和假设“我累了”,我们希望模型能够自动地将假设中的“我”与前提中的“我”对齐,并将“累”与“睡眠”对齐。这种对齐是通过注意力权重实现的。

我们用 A = ( a 1 , … , a m ) \mathbf{A} = (\mathbf{a}_1, \ldots, \mathbf{a}_m) A=(a1,,am) B = ( b 1 , … , b n ) \mathbf{B} = (\mathbf{b}_1, \ldots, \mathbf{b}_n) B=(b1,,bn) 分别表示前提和假设的词向量序列。首先,我们将每个词向量通过一个共享的MLP网络 f f f 进行变换。然后,前提中第 i i i 个词元和假设中第 j j j 个词元的注意力分数 e i j e_{ij} eij 计算如下:

e i j = f ( a i ) ⊤ f ( b j ) e_{ij} = f(\mathbf{a}_i)^\top f(\mathbf{b}_j) eij=f(ai)f(bj)

这里有一个巧妙的“分解”技巧:函数 f f f 分别作用于 a i \mathbf{a}_i ai b j \mathbf{b}_j bj ,而不是将它们配对作为输入。这使得计算复杂度从 O ( m n ) O(mn) O(mn) 降低到了 O ( m + n ) O(m+n) O(m+n),大大提升了效率。

在计算出所有词元对之间的注意力分数后,我们使用Softmax进行归一化,从而得到一个序列对另一个序列的加权平均表示。具体来说:

  1. β i \beta_i βi:对于前提中的每个词元 a i \mathbf{a}_i ai,我们计算它与假设中所有词元对齐后的表示,即假设序列的加权平均。
  2. α j \alpha_j αj:对于假设中的每个词元 b j \mathbf{b}_j bj,我们计算它与前提中所有词元对齐后的表示,即前提序列的加权平均。

下面是实现这个过程的代码。首先,我们定义一个通用的 mlp 函数,它将作为我们模型中的基本构建块。

MLP 辅助函数

import torch
import torch.nn as nndef mlp(num_inputs, num_hiddens, flatten):"""构建多层感知机(MLP)网络参数:num_inputs (int): 输入特征维度num_hiddens (int): 隐藏层神经元数量flatten (bool): 是否在激活函数后进行展平操作返回:nn.Sequential: 构建好的MLP网络模型"""# 创建网络层列表,用于存储各个网络层net = []# 添加第一个Dropout层,防止过拟合net.append(nn.Dropout(0.2))# 添加第一个全连接层,将输入特征映射到隐藏层net.append(nn.Linear(num_inputs, num_hiddens))# 添加ReLU激活函数,引入非线性net.append(nn.ReLU())# 根据flatten参数决定是否添加展平层if flatten:# 将多维张量展平为一维,从第1维开始展平(保留batch维度)net.append(nn.Flatten(start_dim=1))# 添加第二个Dropout层,继续防止过拟合net.append(nn.Dropout(0.2))# 添加第二个全连接层,隐藏层到隐藏层的映射net.append(nn.Linear(num_hiddens, num_hiddens))# 添加第二个ReLU激活函数net.append(nn.ReLU())# 再次根据flatten参数决定是否添加展平层if flatten:# 将多维张量展平为一维net.append(nn.Flatten(start_dim=1))# 将所有网络层组合成Sequential模型并返回return nn.Sequential(*net)

Attend 模块

有了mlp函数,我们就可以构建Attend模块了。

import torch.nn.functional as Fclass Attend(nn.Module):"""注意力机制类,用于计算两个序列之间的软对齐实现论文中提到的注意力机制:e_ij = f(a_i)^T f(b_j)其中f是MLP网络,用于计算注意力权重"""def __init__(self, num_inputs, num_hiddens, **kwargs):"""初始化注意力机制参数:num_inputs (int): 输入特征维度(embed_size)num_hiddens (int): MLP隐藏层维度**kwargs: 传递给父类的其他参数"""super(Attend, self).__init__(**kwargs)# 创建MLP网络f,用于将输入序列映射到注意力空间# 输入维度: (batch_size, seq_len, num_inputs)# 输出维度: (batch_size, seq_len, num_hiddens)self.f = mlp(num_inputs, num_hiddens, flatten=False)def forward(self, A, B):"""前向传播,计算两个序列之间的软对齐参数:A (torch.Tensor): 序列A,形状为(batch_size, seq_A_len, embed_size)B (torch.Tensor): 序列B,形状为(batch_size, seq_B_len, embed_size)返回:beta (torch.Tensor): 序列B对序列A的软对齐,形状为(batch_size, seq_A_len, embed_size)alpha (torch.Tensor): 序列A对序列B的软对齐,形状为(batch_size, seq_B_len, embed_size)"""# 通过MLP网络f处理输入序列A和B# f_A的形状:(batch_size, seq_A_len, num_hiddens)# f_B的形状:(batch_size, seq_B_len, num_hiddens)f_A = self.f(A)f_B = self.f(B)# 计算注意力得分矩阵e# e的形状:(batch_size, seq_A_len, seq_B_len)e = torch.bmm(f_A, f_B.permute(0, 2, 1))# 计算beta:序列B被软对齐到序列A的每个词元# 对最后一维(seq_B_len)进行softmax,得到序列A中每个词元对序列B中所有词元的注意力权重# beta的形状:(batch_size, seq_A_len, embed_size)beta = torch.bmm(F.softmax(e, dim=-1), B)# 计算alpha:序列A被软对齐到序列B的每个词元# 对e进行转置后,对最后一维(seq_A_len)进行softmax# alpha的形状:(batch_size, seq_B_l
http://www.xdnf.cn/news/1073665.html

相关文章:

  • 基于Servlet + Jsp 的在线考试系统
  • 华为云Flexus+DeepSeek征文 | 华为云 ModelArts Studio 赋能高情商AI聊天助手:用技术构建有温度的智能对话体验
  • libevent(2)之使用教程(1)介绍
  • 基于云的平板挠度模拟:动画与建模-AI云计算数值分析和代码验证
  • 多模态大语言模型arxiv论文略读(143)
  • 广度优先搜索BFS(广搜)复习(c++)
  • 深入理解Mysql索引底层数据结构和算法
  • NeRF-Lidar实景重建:大疆Mavic 4 Pro低成本建模方案(2025实战指南)
  • H3C-路由器DHCPV6V4配置标准
  • C++基础(FreeRDP编译)
  • SRS流媒体服务器之本地测试rtc推流bug
  • Python 数据分析:numpy,抽提,整数数组索引。听故事学知识点怎么这么容易?
  • 第八讲——一元函数积分学的概念与性质
  • 【编译原理】期末
  • 设备树引入
  • 【Java--SQL】${}与#{}区别和危害
  • 【EDA软件】【联合Modelsim 同步FIFO仿真】
  • git 挑选:git cherry-pick
  • springboot+Vue逍遥大药房管理系统
  • python中学物理实验模拟:瞬间推力与摩擦力作用下的物体运动
  • 【数据标注师】目标跟踪标注
  • 概述-4-通用语法及分类
  • Word之空白页删除2
  • 基于Pandas和FineBI的昆明职位数据分析与可视化实现(二)- 职位数据清洗与预处理
  • UniApp Vue3 模式下实现页面跳转的全面指南
  • SQL关键字三分钟入门:ROW_NUMBER() —— 窗口函数为每一行编号
  • FreeSWITCH配置文件解析(2) dialplan 拨号计划中xml 的action解析
  • 西门子S7-200 SMART PLC:小型自动化领域的高效之选
  • C语言---常见的字符函数和字符串函数介绍
  • STM32[笔记]--7.MDK5调试功能