当前位置: 首页 > ds >正文

自注意力,多头注意力,交叉注意力代码对比

自注意力、多头注意力与交叉注意力的PyTorch代码对比

1. 自注意力 (Self-Attention)

import torch
import torch.nn as nnclass SelfAttention(nn.Module):def __init__(self, embed_dim):super().__init__()self.embed_dim = embed_dim# 投影矩阵:Q/K/V共享输入维度self.query = nn.Linear(embed_dim, embed_dim)self.key = nn.Linear(embed_dim, embed_dim)self.value = nn.Linear(embed_dim, embed_dim)self.softmax = nn.Softmax(dim=-1)def forward(self, x):"""x: (batch_size, seq_len, embed_dim)"""# 1. 生成Q/K/V - 全部来自同一输入Q = self.query(x)  # (B, L, D)K = self.key(x)    # (B, L, D)V = self.value(x)  # (B, L, D)# 2. 计算注意力分数attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.embed_dim))# 3. 注意力权重归一化attn_weights = self.softmax(attn_scores)  # (B, L, L)# 4. 加权求和output = torch.matmul(attn_weights, V)  # (B, L, D)return output

核心特征

  • Q/K/V全部来自同一个输入序列
  • 注意力分数矩阵维度为(L, L),表示序列内部的关系
  • 输出序列长度和维度不变

2. 多头注意力 (Multi-Head Attention)

class MultiHeadAttention(nn.Module):def __init__(self, embed_dim, num_heads):super().__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.head_dim = embed_dim // num_heads# 确保可分割assert self.head_dim * num_heads == embed_dim, "Embed dim must be divisible by num_heads"# 多头投影矩阵self.query = nn.Linear(embed_dim, embed_dim)self.key = nn.Linear(embed_dim, embed_dim)self.value = nn.Linear(embed_dim, embed_dim)# 输出层self.fc_out = nn.Linear(embed_dim, embed_dim)self.softmax = nn.Softmax(dim=-1)def split_heads(self, x):"""分割为多头"""batch_size = x.size(0)# (B, L, D) -> (B, L, H, HD) -> (B, H, L, HD)return x.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)def forward(self, x):"""多头自注意力"""# 1. 生成Q/K/VQ = self.query(x)K = self.key(x)V = self.value(x)# 2. 分割为多头Q = self.split_heads(Q)  # (B, H, L, HD)K = self.split_heads(K)V = self.split_heads(V)# 3. 计算注意力分数attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim))# 4. 注意力权重归一化attn_weights = self.softmax(attn_scores)  # (B, H, L, L)# 5. 加权求和attention = torch.matmul(attn_weights, V)  # (B, H, L, HD)# 6. 合并多头attention = attention.transpose(1, 2).contiguous()  # (B, L, H, HD)attention = attention.view(attention.size(0), -1, self.embed_dim)  # (B, L, D)# 7. 输出投影output = self.fc_out(attention)return output

核心特征

  • 基于自注意力扩展
  • 额外的分割(head splitting)和合并操作
  • 每个头在降维后的子空间(HD)中计算
  • 最终通过全连接层融合多头信息

3. 交叉注意力 (Cross-Attention)

class CrossAttention(nn.Module):def __init__(self, embed_dim):super().__init__()self.embed_dim = embed_dim# Query来自序列A,Key/Value来自序列Bself.query = nn.Linear(embed_dim, embed_dim)  # for sequence Aself.key = nn.Linear(embed_dim, embed_dim)   # for sequence Bself.value = nn.Linear(embed_dim, embed_dim) # for sequence Bself.softmax = nn.Softmax(dim=-1)def forward(self, x_a, x_b):"""x_a: (batch_size, len_a, embed_dim)  序列Ax_b: (batch_size, len_b, embed_dim)  序列B"""# 1. 生成Q/K/V - 来自不同输入源Q = self.query(x_a)   # 来自序列A (B, La, D)K = self.key(x_b)     # 来自序列B (B, Lb, D)V = self.value(x_b)   # 来自序列B (B, Lb, D)# 2. 计算注意力分数 (序列A到序列B的映射)attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.embed_dim))# 3. 注意力权重归一化attn_weights = self.softmax(attn_scores)  # (B, La, Lb)# 4. 加权求和output = torch.matmul(attn_weights, V)  # (B, La, D)return output

核心特征

  • Q来自一个序列,K/V来自另一个序列
  • 注意力矩阵维度为(La, Lb),表示序列间关系
  • 输出序列长度与查询序列相同(La),维度不变
  • 不要求两个序列长度相同

三者的核心对比

特性自注意力多头注意力交叉注意力
输入序列数量1个1个2个
Q来源自身自身序列A
K/V来源自身自身序列B
维度变换分割头+合并
注意力矩阵(L, L)(H, L, L)(La, Lb)
输出长度LLLa
主要用途序列内关系多角度特征提取序列间关系建模
计算复杂度O(L²·D)O(H·L²·HD)O(La·Lb·D)

使用场景示例

# 示例:序列长度均为5,嵌入维度128
x = torch.randn(2, 5, 128)  # batch_size=2, seq_len=5, embed_dim=128
y = torch.randn(2, 3, 128)  # 不同长度序列# 1. 自注意力
self_attn = SelfAttention(embed_dim=128)
output_self = self_attn(x)  # (2, 5, 128)# 2. 多头注意力 (8头)
multihead_attn = MultiHeadAttention(embed_dim=128, num_heads=8)
output_multi = multihead_attn(x)  # (2, 5, 128)# 3. 交叉注意力
cross_attn = CrossAttention(embed_dim=128)
output_cross = cross_attn(x, y)  # (2, 5, 128) - 保持查询序列长度

性能优化技巧

  1. 融合计算:现代PyTorch版本推荐使用优化API

    # PyTorch 1.12+ 优化实现
    output = F.scaled_dot_product_attention(Q, K, V, attn_mask=None)
    
  2. 内存优化:使用计算过程重算减少内存占用

    with torch.cuda.amp.autocast(enabled=True):output = some_attention(Q, K, V)
    
  3. 稀疏注意力:对大序列使用稀疏矩阵

    from transformers.models.longformer.modeling_longformer import LongformerSelfAttention
    
http://www.xdnf.cn/news/12073.html

相关文章:

  • 【AI学习笔记】Coze工作流写入飞书多维表格(即:多维表格飞书官方插件使用教程)
  • Jenkins的学习与使用(CI/CD)
  • Postgresql常规SQL语句操作
  • 基于Web的安全漏洞分析与修复平台设计与实现
  • 力扣面试150题--岛屿数量
  • 【计算机网络】网络层协议
  • 【vue3学习】vue3入门
  • C++ 变量三
  • [JS逆向] 烯牛数据
  • Spring Boot微服务架构(十):Docker与K8S部署的区别
  • 5090cuda_torch
  • Python训练打卡Day42
  • 前端面试真题(第一集)
  • 解决pycharm同一个文件夹下from *** import***仍显示No module named
  • 结构性设计模式之Facade(外观)设计模式
  • 34.2STM32下的can总线外设_csdn
  • 修改 Windows 10/11 的系统设置中显示的安装日期
  • CMake入门:1、环境搭建
  • 防火墙设置实战操作案例(小白的“升级打怪”成长之路)
  • FreeType 字体信息检查工具 - 现代C++实现
  • selenium学习实战【Python爬虫】
  • 【贪心、DP、线段树优化】Leetcode 376. 摆动序列
  • 当AI遇上防火墙:新一代智能安全解决方案全景解析
  • Elasticsearch中的自定义分析器(Custom Analyzer)介绍
  • 2025最新Java日志框架深度解析:Log4j 2 vs Logback性能实测+企业级实战案例
  • 一个完整的时间序列异常检测系统,使用Flask作为后端框架,实现了AE(自编码器)、TimesNet和LSTM三种模型,并提供可视化展示
  • Asp.Net Core基于StackExchange Redis 缓存
  • 使用TypeScript构建一个最简单的MCP服务器
  • PDF处理控件Aspose.PDF教程:在 C# 中更改 PDF 页面大小
  • 【从零学习JVM|第二篇】字节码文件