多头自注意力机制—Transformer模型的并行特征捕获引擎
作为深度学习领域的革命性突破,
Transformer
模型凭借其卓越的建模能力,已成为自然语言处理(NLP
)的主流架构。其中,多头自注意力机制(Multi-Head Self-Attention
)作为其核心创新组件,通过并行处理多个语义子空间,使模型能够高效捕捉输入序列的全局依赖关系,从而显著提升了特征提取能力。
Transformers-论文+源码
:https://download.csdn.net/download/m0_69402477/90861280
1. Transformer输入表示流程
以中文句子 “我喜欢深度学习” 为例
-
Tokenization :将句子切分为token
"我喜欢深度学习" → ["我", "喜欢", "深度学习"]
-
Token to ID:将token映射为数字ID
["我", "喜欢", "深度学习"] → [259, 372, 5892]
-
Embedding:将ID转换为词向量,形成一个shape为
(seq_len, d_model)
的矩阵d_model
是模型维度,通常设为512- 每个token都被映射到一个512维的高维空间中
-
位置编码(Positional Encoding):为每个token添加位置信息,得到最终输入表示 X
- 整体输入形状为:
(batch_size, seq_len, d_model)
- 整体输入形状为:
因此,Transformer的输入是一个三维张量,shape为:
(batch_size, seq_len, d_model)# batch_size就是一次输入几个。
# seq_len就是句子长度。
# d_model为当前模型的维度。
例如:(1, 3, 512)
表示一次输入一句话,句子长度为3,每词用512维表示。
2. 单头注意力机制简介
在标准的点积注意力中,我们通过线性变换生成 Query (Q)、Key (K)、Value (V):
Q = X W Q , K = X W K , V = X W V Q = XW^Q,\quad K = XW^K,\quad V = XW^V Q=XWQ,K=XWK,V=XWV
然后计算注意力权重:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dkQKT)V
这种单头注意力只能从单一角度关注输入中的相关性,容易遗漏其他潜在的重要语义模式。
3. 多头注意力机制详解
多头注意力机制(Multi-Head Attention)是Transformer模型的核心组件之一,它通过并行地学习多种语义特征,极大地增强了模型的表达能力和泛化能力。多头注意力机制可以直观地理解为“多个脑袋同时关注不同的事情”,从而从全局角度捕捉更全面的信息。以下是多头注意力机制的详细解析。
3.1 核心思想
多头注意力机制的核心思想是通过将Query( Q Q Q)、Key ( K K K)、Value ( V V V) 投影到不同的子空间(subspace),使模型能够并行学习多种语义特征。具体来说,多头注意力机制通过以下四个步骤实现:
-
线性变换生成 Q Q Q、 K K K、 V V V
对于输入 X ∈ R s e q _ l e n × d m o d e l X \in \mathbb{R}^{seq\_len \times d_{model}} X∈Rseq_len×dmodel,通过三组权重矩阵 W q i , W k i , W v i W_q^i, W_k^i, W_v^i Wqi,Wki,Wvi,分别计算每个头的 Q i Q^i Qi、 K i K^i Ki、 V i V^i Vi:
Q i = X W q i , K i = X W k i , V i = X W v i Q^i = X W_q^i, \quad K^i = X W_k^i, \quad V^i = X W_v^i Qi=XWqi,Ki=XWki,Vi=XWvi
其中:- W q i , W k i , W v i ∈ R d m o d e l × d k W_q^i, W_k^i, W_v^i \in \mathbb{R}^{d_{model} \times d_k} Wqi,Wki,Wvi∈Rdmodel×dk, d k = d m o d e l / h d_k = d_{model} / h dk=dmodel/h, h h h 是头的数量。
- 每个头的维度 d k d_k dk 是模型维度 d m o d e l d_{model} dmodel 的一部分,确保每个头专注于不同的子空间。
-
分别计算每个头的注意力
对于每个头 i i i,独立计算注意力得分:
h e a d i = Attention ( Q i , K i , V i ) head_i = \text{Attention}(Q^i, K^i, V^i) headi=Attention(Qi,Ki,Vi)
注意力计算公式为:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V Attention(Q,K,V)=softmax(dkQKT)V
其中:- Q i , K i , V i ∈ R s e q _ l e n × d k Q^i, K^i, V^i \in \mathbb{R}^{seq\_len \times d_k} Qi,Ki,Vi∈Rseq_len×dk。
- 每个头独立计算注意力权重,从而捕捉不同的语义特征。
-
投影回原空间
将所有头的输出拼接起来,并通过一个线性变换 W O W^O WO 投影回原空间:
MultiHead ( Q , K , V ) = Concat ( h e a d 1 , … , h e a d h ) W O \text{MultiHead}(Q, K, V) = \text{Concat}(head_1, \dots, head_h) W^O MultiHead(Q,K,V)=Concat(head1,…,headh)WO
其中:- W O ∈ R ( h ⋅ d k ) × d m o d e l W^O \in \mathbb{R}^{(h \cdot d_k) \times d_{model}} WO∈R(h⋅dk)×dmodel,用于将拼接后的输出投影回原始维度 d m o d e l d_{model} dmodel。
-
拼接所有头的输出
最终,我们将所有头的输出拼接起来,形成多头注意力的最终输出:
MultiHead ( Q , K , V ) = Concat ( h e a d 1 , … , h e a d h ) W O \text{MultiHead}(Q, K, V) = \text{Concat}(head_1, \dots, head_h) W^O MultiHead(Q,K,V)=Concat(head1,…,headh)WO
3.2 数学表示
多头注意力的完整数学表示如下:
-
线性变换生成 Q Q Q、 K K K、 V V V:
Q i = X W q i , K i = X W k i , V i = X W v i , i = 1 , 2 , … , h Q^i = X W_q^i, \quad K^i = X W_k^i, \quad V^i = X W_v^i, \quad i = 1, 2, \dots, h Qi=XWqi,Ki=XWki,Vi=XWvi,i=1,2,…,h
其中:- W q i , W k i , W v i ∈ R d m o d e l × d k W_q^i, W_k^i, W_v^i \in \mathbb{R}^{d_{model} \times d_k} Wqi,Wki,Wvi∈Rdmodel×dk, d k = d m o d e l / h d_k = d_{model} / h dk=dmodel/h。
-
计算每个头的注意力:
h e a d i = Attention ( Q i , K i , V i ) = softmax ( Q i K i T d k ) V i head_i = \text{Attention}(Q^i, K^i, V^i) = \text{softmax}\left(\frac{Q^i K^{i^T}}{\sqrt{d_k}}\right) V^i headi=Attention(Qi,Ki,Vi)=softmax(dkQiKiT)Vi -
拼接所有头的输出并投影回原空间:
MultiHead ( Q , K , V ) = Concat ( h e a d 1 , … , h e a d h ) W O \text{MultiHead}(Q, K, V) = \text{Concat}(head_1, \dots, head_h) W^O MultiHead(Q,K,V)=Concat(head1,…,headh)WO
其中:- Concat ( h e a d 1 , … , h e a d h ) ∈ R s e q _ l e n × ( h ⋅ d k ) \text{Concat}(head_1, \dots, head_h) \in \mathbb{R}^{seq\_len \times (h \cdot d_k)} Concat(head1,…,headh)∈Rseq_len×(h⋅dk)。
- W O ∈ R ( h ⋅ d k ) × d m o d e l W^O \in \mathbb{R}^{(h \cdot d_k) \times d_{model}} WO∈R(h⋅dk)×dmodel。
3.3 理论与实践的差异
在实际实现中,多头注意力的计算方式与理论描述略有不同。为了提高计算效率,我们通常不会为每个头单独维护权重矩阵 W q i , W k i , W v i W_q^i, W_k^i, W_v^i Wqi,Wki,Wvi,而是通过一个完整的大矩阵来实现多头注意力。具体来说:
- 权重矩阵的合并:
- W q ∈ R d m o d e l × ( h ⋅ d k ) W_q \in \mathbb{R}^{d_{model} \times (h \cdot d_k)} Wq∈Rdmodel×(h⋅dk),用于生成所有头的 Q Q Q。
- W k ∈ R d m o d e l × ( h ⋅ d k ) W_k \in \mathbb{R}^{d_{model} \times (h \cdot d_k)} Wk∈Rdmodel×(h⋅dk),用于生成所有头的 K K K。
- W v ∈ R d m o d e l × ( h ⋅ d k ) W_v \in \mathbb{R}^{d_{model} \times (h \cdot d_k)} Wv∈Rdmodel×(h⋅dk),用于生成所有头的 V V V。
- 计算流程:
- 通过大矩阵计算所有头的 Q Q Q、 K K K、 V V V:
Q = X W q , K = X W k , V = X W v Q = X W_q, \quad K = X W_k, \quad V = X W_v Q=XWq,K=XWk,V=XWv
其中:- Q , K , V ∈ R s e q _ l e n × ( h ⋅ d k ) Q, K, V \in \mathbb{R}^{seq\_len \times (h \cdot d_k)} Q,K,V∈Rseq_len×(h⋅dk)。
- 将 Q Q Q、 K K K、 V V V 拆分为 h h h个头:
Q = split_heads ( Q ) , K = split_heads ( K ) , V = split_heads ( V ) Q = \text{split\_heads}(Q), \quad K = \text{split\_heads}(K), \quad V = \text{split\_heads}(V) Q=split_heads(Q),K=split_heads(K),V=split_heads(V)
拆分后,每个头的形状为:
Q i , K i , V i ∈ R s e q _ l e n × d k Q^i, K^i, V^i \in \mathbb{R}^{seq\_len \times d_k} Qi,Ki,Vi∈Rseq_len×dk - 分别计算每个头的注意力:
h e a d i = Attention ( Q i , K i , V i ) head_i = \text{Attention}(Q^i, K^i, V^i) headi=Attention(Qi,Ki,Vi) - 拼接所有头的输出并投影回原空间:
MultiHead ( Q , K , V ) = Concat ( h e a d 1 , … , h e a d h ) W O \text{MultiHead}(Q, K, V) = \text{Concat}(head_1, \dots, head_h) W^O MultiHead(Q,K,V)=Concat(head1,…,headh)WO
- 通过大矩阵计算所有头的 Q Q Q、 K K K、 V V V:
3.4多头注意力的优势
多头注意力机制通过并行学习多种子空间特征,具有以下显著优势:
- 并行性:多个头并行工作,提升计算效率。
- 多样性:不同头学习不同的语义模式,增强模型的表达能力。
- 鲁棒性:冗余设计提高模型的容错性。
- 可解释性:可以通过可视化注意力权重,分析模型关注的重点。
如何运用多头注意力机制完整流程图如下:
4. 代码实现解析
import torch
import torch.nn as nn
import torch.nn.functional as FclassMultiHeadAttention(nn.Module):def__init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()assert d_model % num_heads == 0, "d_model must be divisible by num_heads"self.d_model = d_modelself.num_heads = num_headsself.head_dim = d_model // num_headsself.W_q = nn.Linear(d_model, d_model)self.W_k = nn.Linear(d_model, d_model)self.W_v = nn.Linear(d_model, d_model)self.W_o = nn.Linear(d_model, d_model)defsplit_heads(self, x):batch_size, seq_len, _ = x.size()x = x.view(batch_size, seq_len, self.num_heads, self.head_dim)return x.transpose(1, 2)defscaled_dot_product_attention(self, Q, K, V, mask=None):scores = torch.matmul(Q, K.transpose(-2, -1))scores = scores / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))if mask isnotNone:scores = scores.masked_fill(mask == 0, -1e9)attention_weights = F.softmax(scores, dim=-1)output = torch.matmul(attention_weights, V)return output, attention_weightsdefforward(self, query, key, value, mask=None):"""前向传播过程query/key/value: 输入张量,形状均为(batch_size, seq_len, d_model)mask: 可选的掩码张量返回:输出张量和注意力权重"""# 重点关注forward里和理论不同的部分# 1. 线性变换生成Q/K/VQ = self.W_q(query) # (batch_size, seq_len, d_model)K = self.W_k(key)V = self.W_v(value)# 2. 分割为多个头的表示Q = self.split_heads(Q) # (batch_size, num_heads, seq_len, head_dim)K = self.split_heads(K)V = self.split_heads(V)# 3. 计算多头注意力attention_output, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)# 4. 合并多个头的输出# 先转置回(batch_size, seq_len, num_heads, head_dim)attention_output = attention_output.transpose(1, 2)# 合并最后一个维度(num_heads * head_dim = d_model)batch_size, seq_len, _, _ = attention_output.size()concat_output = attention_output.contiguous().view(batch_size, seq_len, self.d_model)# 5. 最终的线性变换(W_o)output = self.W_o(concat_output) # (batch_size, seq_len, d_model)return output, attention_weightsif __name__ == "__main__":batch_size = 2seq_len = 10d_model = 512num_heads = 8query = torch.randn(batch_size, seq_len, d_model)key = torch.randn(batch_size, seq_len, d_model)value = torch.randn(batch_size, seq_len, d_model)mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)output, attn_weights = mha(query, key, value)print("输入形状:", query.shape)print("输出形状:", output.shape)print("注意力权重形状:", attn_weights.shape)
5. 为什么多头可以代表多种子语义?
多头注意力机制的核心目标是通过并行地学习不同的子空间特征,从而捕捉到输入序列中的多种语义模式。这种能力的实现依赖于多个关键机制,包括参数独立性、非线性计算以及损失函数的隐式正则化等。以下是详细的分析:
5.1 参数独立性
每个注意力头都有独立的参数矩阵 W i Q , W i K , W i V W^Q_i, W^K_i, W^V_i WiQ,WiK,WiV,这些参数在初始化时是随机生成的,且彼此之间没有共享。由于随机初始化的差异性,不同头的参数矩阵从一开始就处于不同的初始状态。
- 随机初始化:每个头的参数矩阵 W i Q , W i K , W i V W^Q_i, W^K_i, W^V_i WiQ,WiK,WiV是独立初始化的,这意味着它们在训练开始时就已经具备了学习不同特征的潜力。
- 梯度更新的独立性:在反向传播过程中,每个头的参数矩阵会根据其自身的梯度信号进行更新,而不会受到其他头的影响。这种独立性确保了不同头的学习路径不会完全重叠。
例如: - 头 H 1 H_1 H1的参数矩阵 W 1 Q , W 1 K , W 1 V W^Q_1, W^K_1, W^V_1 W1Q,W1K,W1V可能倾向于捕捉局部的短距离依赖关系(如词与词之间的直接关联)。
- 头 H 2 H_2 H2的参数矩阵 W 2 Q , W 2 K , W 2 V W^Q_2, W^K_2, W^V_2 W2Q,W2K,W2V则可能专注于长程依赖关系(如句子中跨多个词的上下文信息)。
5.2 梯度多样性
在多头注意力机制中,每个头的参数矩阵 W i Q , W i K , W i V W^Q_i, W^K_i, W^V_i WiQ,WiK,WiV都会接收到独立的梯度信号。这种梯度多样性是优化过程的关键,它迫使不同头的学习方向逐渐分化。
5.2.1 梯度计算公式
对于第 i i i个头的Query 矩阵 Q i Q^i Qi,其参数矩阵 W i Q W^Q_i WiQ的梯度为:
∇ W i Q L = ∂ L ∂ W i Q = ∂ L ∂ head i ⋅ ∂ head i ∂ Q i ⋅ ∂ Q i ∂ W i Q \nabla_{W^Q_i} \mathcal{L} = \frac{\partial \mathcal{L}}{\partial W^Q_i} = \frac{\partial \mathcal{L}}{\partial \text{head}_i} \cdot \frac{\partial \text{head}_i}{\partial Q^i} \cdot \frac{\partial Q^i}{\partial W^Q_i} ∇WiQL=∂WiQ∂L=∂headi∂L⋅∂Qi∂headi⋅∂WiQ∂Qi
- 损失函数对头的敏感性:不同头的输出 head i \text{head}_i headi 对损失函数 L \mathcal{L} L 的贡献不同,这会导致每个头的梯度 ∇ W i Q L \nabla_{W^Q_i} \mathcal{L} ∇WiQL具有不同的分布。
- 注意力权重的非线性性:注意力权重的计算公式包含softmax函数:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dkQKT)V
softmax的指数运算会放大某些位置的权重,同时抑制其他位置的权重,这种非线性特性进一步增强了梯度的多样性。
5.2.2 梯度多样性的结果
- 不同头的梯度信号会引导它们学习不同的特征模式。
- 如果某个头的输出对当前任务的损失贡献较大,那么它的参数矩阵会接收更强的梯度信号,从而更快速地调整以适应任务需求。
5.3 注意力权重计算的非线性性
注意力机制的核心公式是:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dkQKT)V
其中,softmax函数具有显著的非线性特性:
- 指数放大效应:softmax函数会对分数较高的位置赋予更高的权重,而分数较低的位置权重会被显著抑制。
- 竞争机制:在计算注意力权重时,不同位置之间的权重分配是相互竞争的,这导致每个头的注意力分布具有独特性。
例如:
- 头 H 1 H_1 H1可能关注句子中的高频词汇或局部结构。
- 头 H 2 H_2 H2可能关注稀疏但重要的长程依赖关系。
这种非线性计算使得每个头的注意力分布呈现出不同的模式,进一步增强了多头注意力的多样化特征捕获能力。
5.4 损失函数的隐式正则化
在多头注意力机制中,如果多个头的学习方向高度一致,那么模型的参数更新也会趋于一致,这可能导致梯度消失或冗余问题。为了避免这种情况,损失函数会自动调整,促使不同头的学习内容差异化。
- 避免冗余:如果多个头学习相似的内容,它们的梯度会趋于一致,导致模型无法充分利用多头的优势。损失函数会通过优化过程惩罚这种冗余,促使不同头学习互补的特征。
- 促进多样性:损失函数会优先选择那些能够有效降低整体损失的头,从而自然地推动不同头的学习方向分化。
5.5 头之间的正交性
随着训练的推进,不同头的参数矩阵 W i Q , W i K , W i V W^Q_i, W^K_i, W^V_i WiQ,WiK,WiV会逐渐展现出正交性。这种正交性表明不同头在不同的子空间中学习到了互补的信息。
- 正交性的来源:当不同头的梯度方向和更新路径不同时,它们的参数矩阵会逐渐远离彼此,形成正交关系。
- 意义:正交性说明每个头都在探索一个独特的语义子空间,从而实现了对输入数据的全面覆盖。
总结下来可以归纳出重要的三点:
- 参数独立性 :每个头的参数矩阵独立初始化,且各自接收独立的梯度信号,保证了学习路径的多样性。
- 注意力权重的非线性计算:softmax函数的指数运算放大了特定部分的权重,抑制了其他部分,进一步增强了头之间的差异性。
- 损失函数的隐式正则化 :损失函数会自动调整,避免多个头学习相同的内容,促使它们学习互补的特征。
6. 举个栗子:局部vs长程依赖
假设我们有两个头 H 1 H_1 H1和 H 2 H_2 H2:
- H 1 H_1 H1:专注于局部结构,例如识别短语 “尊贵的X1车主”
- 这个头可能学习到词与词之间的直接关联,例如“尊贵的”和“车主”的局部关系。
- H 2 H_2 H2:捕捉长程依赖,例如理解“他最喜欢的车是保时捷,但准备买的是X5,最终妥协买了X1”
- 这个头可能关注句子中的跨句信息,例如“保时捷”、“X5”和“X1”之间的长程依赖。
当预测下一个词时,如果局部结构不足以支撑预测,则损失函数会对 H 2 H_2 H2发出更强的梯度信号,迫使它更注重跨句信息的学习。这种机制确保了不同头能够分别关注不同的语义模式。
7. 头的数量如何选定?
选择头的数量是一个综合考虑计算效率、特征多样性和维度分配的过程:
- 经验性调优:通常通过实验验证不同头数量在不同任务下的表现,选择性能最佳的配置。
- 维度分配原则:每个头的维度需要足够大以捕获有效信息,常见设置为每个头64或128维。
- 计算效率:增加头的数量可以提升模型的并行性,但也增加了计算开销。
常见的模型头数量如下:- transformer: 8头,每个头64维,总维度512- BERT-base: 12头,每个头64维,总维度768- BERT-large:16头,每个头64维,总维度1024- GPT-3 175B:96头,每个头128维,总维度12288头数量分配原则:- 每个头的维度需要足够大以捕获有效信息,从过往经验来看通常是≥64维- 经验性调优:测试不同的头在不同任务下效果,选择性能最佳。- 计算效率:头数量增加可以提升模型的并行性。
8. 输入维度的变化
9. 结语
多头注意力机制不仅是Transformer架构的基石,更是现代大语言模型(LLM)成功的关键所在。通过并行地学习多种语义子空间,它实现了对复杂语言结构的高效建模,为NLP任务带来了革命性的突破,掌握其背后的数学原理与工程实现,对于构建、优化甚至解释Transformer类模型具有重要意义。