当前位置: 首页 > ai >正文

【现代深度学习技术】注意力机制05:多头注意力

在这里插入图片描述

【作者主页】Francek Chen
【专栏介绍】 ⌈ ⌈ PyTorch深度学习 ⌋ ⌋ 深度学习 (DL, Deep Learning) 特指基于深层神经网络模型和方法的机器学习。它是在统计机器学习、人工神经网络等算法模型基础上,结合当代大数据和大算力的发展而发展出来的。深度学习最重要的技术特征是具有自动提取特征的能力。神经网络算法、算力和数据是开展深度学习的三要素。深度学习在计算机视觉、自然语言处理、多模态数据分析、科学探索等领域都取得了很多成果。本专栏介绍基于PyTorch的深度学习算法实现。
【GitCode】专栏资源保存在我的GitCode仓库:https://gitcode.com/Morse_Chen/PyTorch_deep_learning。

文章目录

    • 一、模型
    • 二、实现
    • 小结


  在实践中,当给定相同的查询、键和值的集合时,我们希望模型可以基于相同的注意力机制学习到不同的行为,然后将不同的行为作为知识组合起来,捕获序列内各种范围的依赖关系(例如,短距离依赖和长距离依赖关系)。因此,允许注意力机制组合使用查询、键和值的不同子空间表示(representation subspaces)可能是有益的。

  为此,与其只使用单独一个注意力汇聚,我们可以用独立学习得到的 h h h组不同的线性投影(linear projections)来变换查询、键和值。然后,这 h h h组变换后的查询、键和值将并行地送到注意力汇聚中。最后,将这 h h h个注意力汇聚的输出拼接在一起,并且通过另一个可以学习的线性投影进行变换,以产生最终输出。这种设计被称为多头注意力(multihead attention)。对于 h h h个注意力汇聚输出,每一个注意力汇聚都被称作一个(head)。图1展示了使用全连接层来实现可学习的线性变换的多头注意力。

在这里插入图片描述

图1 多头注意力:多个头连结然后线性变换

一、模型

  在实现多头注意力之前,让我们用数学语言将这个模型形式化地描述出来。给定查询 q ∈ R d q \mathbf{q} \in \mathbb{R}^{d_q} qRdq、键 k ∈ R d k \mathbf{k} \in \mathbb{R}^{d_k} kRdk和值 v ∈ R d v \mathbf{v} \in \mathbb{R}^{d_v} vRdv,每个注意力头 h i \mathbf{h}_i hi i = 1 , … , h i = 1, \ldots, h i=1,,h)的计算方法为:
h i = f ( W i ( q ) q , W i ( k ) k , W i ( v ) v ) ∈ R p v (1) \mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v} \tag{1} hi=f(Wi(q)q,Wi(k)k,Wi(v)v)Rpv(1) 其中,可学习的参数包括 W i ( q ) ∈ R p q × d q \mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q} Wi(q)Rpq×dq W i ( k ) ∈ R p k × d k \mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k} Wi(k)Rpk×dk W i ( v ) ∈ R p v × d v \mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v} Wi(v)Rpv×dv,以及代表注意力汇聚的函数 f f f f f f可以是注意力评分函数中的加性注意力和缩放点积注意力。多头注意力的输出需要经过另一个线性转换,它对应着 h h h个头连结后的结果,因此其可学习参数是 W o ∈ R p o × h p v \mathbf W_o\in\mathbb R^{p_o\times h p_v} WoRpo×hpv
W o [ h 1 ⋮ h h ] ∈ R p o (2) \mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o} \tag{2} Wo h1hh Rpo(2)

  基于这种设计,每个头都可能会关注输入的不同部分,可以表示比简单加权平均值更复杂的函数。

import math
import torch
from torch import nn
from d2l import torch as d2l

二、实现

  在实现过程中通常选择缩放点积注意力作为每一个注意力头。为了避免计算代价和参数代价的大幅增长,我们设定 p q = p k = p v = p o / h p_q = p_k = p_v = p_o / h pq=pk=pv=po/h。值得注意的是,如果将查询、键和值的线性变换的输出数量设置为 p q h = p k h = p v h = p o p_q h = p_k h = p_v h = p_o pqh=pkh=pvh=po,则可以并行计算 h h h个头。在下面的实现中, p o p_o po是通过参数num_hiddens指定的。

#@save
class MultiHeadAttention(nn.Module):"""多头注意力"""def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):super(MultiHeadAttention, self).__init__(**kwargs)self.num_heads = num_headsself.attention = d2l.DotProductAttention(dropout)self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)def forward(self, queries, keys, values, valid_lens):# queries,keys,values的形状:# (batch_size,查询或者“键-值”对的个数,num_hiddens)# valid_lens 的形状:# (batch_size,)或(batch_size,查询的个数)# 经过变换后,输出的queries,keys,values 的形状:# (batch_size*num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)queries = transpose_qkv(self.W_q(queries), self.num_heads)keys = transpose_qkv(self.W_k(keys), self.num_heads)values = transpose_qkv(self.W_v(values), self.num_heads)if valid_lens is not None:# 在轴0,将第一项(标量或者矢量)复制num_heads次,# 然后如此复制第二项,然后诸如此类。valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)# output的形状:(batch_size*num_heads,查询的个数,# num_hiddens/num_heads)output = self.attention(queries, keys, values, valid_lens)# output_concat的形状:(batch_size,查询的个数,num_hiddens)output_concat = transpose_output(output, self.num_heads)return self.W_o(output_concat)

  为了能够使多个头并行计算,上面的MultiHeadAttention类将使用下面定义的两个转置函数。具体来说,transpose_output函数反转了transpose_qkv函数的操作。

#@save
def transpose_qkv(X, num_heads):"""为了多注意力头的并行计算而变换形状"""# 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)# 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,# num_hiddens/num_heads)X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)# 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)X = X.permute(0, 2, 1, 3)# 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)return X.reshape(-1, X.shape[2], X.shape[3])#@save
def transpose_output(X, num_heads):"""逆转transpose_qkv函数的操作"""X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])X = X.permute(0, 2, 1, 3)return X.reshape(X.shape[0], X.shape[1], -1)

  下面使用键和值相同的小例子来测试我们编写的MultiHeadAttention类。多头注意力输出的形状是(batch_sizenum_queriesnum_hiddens)。

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5)
attention.eval()

在这里插入图片描述

batch_size, num_queries = 2, 4
num_kvpairs, valid_lens =  6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape

在这里插入图片描述

小结

  • 多头注意力融合了来自于多个注意力汇聚的不同知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。
  • 基于适当的张量操作,可以实现多头注意力的并行计算。
http://www.xdnf.cn/news/5480.html

相关文章:

  • c#修改ComboBox当前选中项的文本
  • python生成八位密码包含字母和数字
  • 构建DEEPPOLAR ——Architecture for DEEPPOLAR (256,37)
  • 智能手表测试用例文档
  • 医学影像处理与可视化:从预处理到 3D 重建的技术实践
  • 无偿帮写毕业论文
  • 工业4.0时代下的人工智能新发展
  • C++跨平台开发概述
  • Python基础语法(中)
  • C++ stl中的priority_queue的相关函数用法
  • 基于Docker的Bitwarden的私有本地部署
  • Linux 进程控制 基础IO
  • 精读计算机体系结构基础 第三章 特权指令系统
  • 接口的基础定义与属性约束
  • 操作系统 : 线程同步与互斥
  • [Java实战]springboot注解@ControllerAdvice解析(十二)
  • DeepSeek模型微调指南:解锁高级技术,引领AI新变革
  • 信息系统项目管理师-软考高级(软考高项)​​​​​​​​​​​2025最新(十五)
  • DeepSeek:开启能源领域智能化变革新时代
  • MySQL索引概述
  • C/C++复习--C语言隐式类型转换
  • stm32 WDG看门狗
  • MySQL数据库常见面试题之三大范式
  • Python打卡训练营Day22
  • 并发笔记-锁(一)
  • AI 小智代码架构分析
  • 【DNDC模型】双碳目标下DNDC模型建模方法及在土壤碳储量、温室气体排放、农田减排、土地变化、气候变化中的应用
  • Java 中 AQS 的实现原理
  • Problem B: 面向对象综合题2
  • [思维模式-27]:《本质思考力》-7- 逆向思考的原理与应用