当前位置: 首页 > news >正文

【现代深度学习技术】注意力机制04:Bahdanau注意力

在这里插入图片描述

【作者主页】Francek Chen
【专栏介绍】 ⌈ ⌈ PyTorch深度学习 ⌋ ⌋ 深度学习 (DL, Deep Learning) 特指基于深层神经网络模型和方法的机器学习。它是在统计机器学习、人工神经网络等算法模型基础上,结合当代大数据和大算力的发展而发展出来的。深度学习最重要的技术特征是具有自动提取特征的能力。神经网络算法、算力和数据是开展深度学习的三要素。深度学习在计算机视觉、自然语言处理、多模态数据分析、科学探索等领域都取得了很多成果。本专栏介绍基于PyTorch的深度学习算法实现。
【GitCode】专栏资源保存在我的GitCode仓库:https://gitcode.com/Morse_Chen/PyTorch_deep_learning。

文章目录

    • 一、模型
    • 二、定义注意力解码器
    • 三、训练
    • 小结


  序列到序列学习(seq2seq)中探讨了机器翻译问题:通过设计一个基于两个循环神经网络的编码器-解码器架构,用于序列到序列学习。具体来说,循环神经网络编码器将长度可变的序列转换为固定形状的上下文变量,然后循环神经网络解码器根据生成的词元和上下文变量按词元生成输出(目标)序列词元。然而,即使并非所有输入(源)词元都对解码某个词元都有用,在每个解码步骤中仍使用编码相同的上下文变量。有什么方法能改变上下文变量呢?

  我们试着找到灵感:在为给定文本序列生成手写的挑战中,Graves设计了一种可微注意力模型,将文本字符与更长的笔迹对齐,其中对齐方式仅向一个方向移动。受学习对齐想法的启发,Bahdanau等人提出了一个没有严格单向对齐限制的可微注意力模型。在预测词元时,如果不是所有输入词元都相关,模型将仅对齐(或参与)输入序列中与当前预测相关的部分。这是通过将上下文变量视为注意力集中的输出来实现的。

一、模型

  下面描述的Bahdanau注意力模型将遵循序列到序列学习(seq2seq)中的相同符号表达。这个新的基于注意力的模型与序列到序列学习(seq2seq)中的模型相同,只不过其中式(3)中的上下文变量 c \mathbf{c} c在任何解码时间步 t ′ t' t都会被 c t ′ \mathbf{c}_{t'} ct替换。假设输入序列中有 T T T个词元,解码时间步 t ′ t' t的上下文变量是注意力集中的输出:
c t ′ = ∑ t = 1 T α ( s t ′ − 1 , h t ) h t (1) \mathbf{c}_{t'} = \sum_{t=1}^T \alpha(\mathbf{s}_{t' - 1}, \mathbf{h}_t) \mathbf{h}_t \tag{1} ct=t=1Tα(st1,ht)ht(1) 其中,时间步 t ′ − 1 t' - 1 t1时的解码器隐状态 s t ′ − 1 \mathbf{s}_{t' - 1} st1是查询,编码器隐状态 h t \mathbf{h}_t ht既是键,也是值,注意力权重 α \alpha α是使用加性注意力打分函数计算的。

  与循环神经网络编码器-解码器架构略有不同,图1描述了Bahdanau注意力的架构。

在这里插入图片描述

图1 一个带有Bahdanau注意力的循环神经网络编码器-解码器模型

import torch
from torch import nn
from d2l import torch as d2l

二、定义注意力解码器

  下面看看如何定义Bahdanau注意力,实现循环神经网络编码器-解码器。其实,我们只需重新定义解码器即可。为了更方便地显示学习的注意力权重,以下AttentionDecoder类定义了带有注意力机制解码器的基本接口。

#@save
class AttentionDecoder(d2l.Decoder):"""带有注意力机制解码器的基本接口"""def __init__(self, **kwargs):super(AttentionDecoder, self).__init__(**kwargs)@propertydef attention_weights(self):raise NotImplementedError

  接下来,让我们在接下来的Seq2SeqAttentionDecoder类中实现带有Bahdanau注意力的循环神经网络解码器。首先,初始化解码器的状态,需要下面的输入:

  1. 编码器在所有时间步的最终层隐状态,将作为注意力的键和值;
  2. 上一时间步的编码器全层隐状态,将作为初始化解码器的隐状态;
  3. 编码器有效长度(排除在注意力池中填充词元)。

  在每个解码时间步骤中,解码器上一个时间步的最终层隐状态将用作查询。因此,注意力输出和输入嵌入都连结为循环神经网络解码器的输入。

class Seq2SeqAttentionDecoder(AttentionDecoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)self.attention = d2l.AdditiveAttention(num_hiddens, num_hiddens, num_hiddens, dropout)self.embedding = nn.Embedding(vocab_size, embed_size)self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)self.dense = nn.Linear(num_hiddens, vocab_size)def init_state(self, enc_outputs, enc_valid_lens, *args):# outputs的形状为(batch_size,num_steps,num_hiddens).# hidden_state的形状为(num_layers,batch_size,num_hiddens)outputs, hidden_state = enc_outputsreturn (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)def forward(self, X, state):# enc_outputs的形状为(batch_size,num_steps,num_hiddens).# hidden_state的形状为(num_layers,batch_size,# num_hiddens)enc_outputs, hidden_state, enc_valid_lens = state# 输出X的形状为(num_steps,batch_size,embed_size)X = self.embedding(X).permute(1, 0, 2)outputs, self._attention_weights = [], []for x in X:# query的形状为(batch_size,1,num_hiddens)query = torch.unsqueeze(hidden_state[-1], dim=1)# context的形状为(batch_size,1,num_hiddens)context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)# 在特征维度上连结x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)# 将x变形为(1,batch_size,embed_size+num_hiddens)out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)outputs.append(out)self._attention_weights.append(self.attention.attention_weights)# 全连接层变换后,outputs的形状为# (num_steps,batch_size,vocab_size)outputs = self.dense(torch.cat(outputs, dim=0))return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]@propertydef attention_weights(self):return self._attention_weights

  接下来,使用包含7个时间步的4个序列输入的小批量测试Bahdanau注意力解码器。

encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
encoder.eval()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()
X = torch.zeros((4, 7), dtype=torch.long)  # (batch_size,num_steps)
state = decoder.init_state(encoder(X), None)
output, state = decoder(X, state)
output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape

在这里插入图片描述

三、训练

  与序列到序列学习(seq2seq)类似,我们在这里指定超参数,实例化一个带有Bahdanau注意力的编码器和解码器,并对这个模型进行机器翻译训练。由于新增的注意力机制,训练要序列到序列学习(seq2seq)比没有注意力机制的慢得多。

embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 250, d2l.try_gpu()train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = d2l.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

在这里插入图片描述
在这里插入图片描述

  模型训练后,我们用它将几个英语句子翻译成法语并计算它们的BLEU分数。

engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):translation, dec_attention_weight_seq = d2l.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device, True)print(f'{eng} => {translation}, ', f'bleu {d2l.bleu(translation, fra, k=2):.3f}')

在这里插入图片描述

attention_weights = torch.cat([step[0][0][0] for step in dec_attention_weight_seq], 0).reshape((1, 1, -1, num_steps))

  训练结束后,下面通过可视化注意力权重会发现,每个查询都会在键值对上分配不同的权重,这说明在每个解码步中,输入序列的不同部分被选择性地聚集在注意力池中。

# 加上一个包含序列结束词元
d2l.show_heatmaps(attention_weights[:, :, :, :len(engs[-1].split()) + 1].cpu(),xlabel='Key positions', ylabel='Query positions')

在这里插入图片描述

小结

  • 在预测词元时,如果不是所有输入词元都是相关的,那么具有Bahdanau注意力的循环神经网络编码器-解码器会有选择地统计输入序列的不同部分。这是通过将上下文变量视为加性注意力池化的输出来实现的。
  • 在循环神经网络编码器-解码器中,Bahdanau注意力将上一时间步的解码器隐状态视为查询,在所有时间步的编码器隐状态同时视为键和值。
http://www.xdnf.cn/news/369163.html

相关文章:

  • SwarmUI:基于.Net开发的开源AI 图像生成 Web 用户界面系统
  • GPT-4o, GPT 4.5, GPT 4.1, O3, O4-mini等模型的区别与联系
  • n8n系列(5):LangChain与大语言模型应用
  • Vue3 怎么在ElMessage消息提示组件中添加自定义icon图标
  • 【 Redis | 实战篇 缓存 】
  • VS小技巧:如何在一个项目中添加其他项目
  • 电位器如何接入西门子PLC的模拟量输入
  • 01 dnsmasq 中 dns服务
  • 【大模型面试每日一题】Day 13:数据并行与模型并行的区别是什么?ZeRO优化器如何结合二者?
  • 背单词软件开发英语App英语提分宝超级单词表,河南数匠软件开发
  • PCBA是电子设备的核心大脑!
  • node提示node:events:495 throw er解决方法
  • C语言编程--19.括号生成
  • 手动修改uart16550的FIFO深度?
  • STM32F103VE 三种低功耗模式
  • CN3791 锂电池充电芯片详解及电路设计要点-国产芯片
  • java-多态
  • 机舱巡飞平台技术要点突破点详解!
  • 流式渲染 Streaming SSR
  • deep seek简介和解析
  • BERT模型讲解
  • 【C语言指针超详解(三)】--数组名的理解,一维数组传参的本质,冒泡排序,二级指针,指针数组
  • 开平机:技术深水区与产业变革的融合突破
  • spring ai alibaba ChatClient 获取大模型返回内容的方式 以及使用场景
  • 什么是 HEIC 格式?如何在电脑上查看HEIC格式的图像?
  • 软件开发的图表类型
  • RAG优化知识库检索(1):基础概念与架构
  • 结构性变革与新兴机遇
  • 如何评估SAP升级实施商的专业能力?
  • JWT原理及工作流程详解