当前位置: 首页 > ds >正文

Multi-Query Attention (MQA) PyTorch 实现

和多头注意力机制的唯一区别:K、V在不同的head之间实现了复用,而对于不同的头,Q依然不同。

因此这里的代码和标准多头注意力的实现也是几乎完全一样:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass MultiQueryAttention(nn.Module):def __init__(self, embed_dim, num_heads):super().__init__()self.num_heads = num_headsself.head_dim = embed_dim // num_headsself.scale = self.head_dim ** -0.5# 查询、键、值投影self.q_proj = nn.Linear(embed_dim, embed_dim)  # 多头查询self.k_proj = nn.Linear(embed_dim, self.head_dim)  # 单头键self.v_proj = nn.Linear(embed_dim, self.head_dim)  # 单头值self.out_proj = nn.Linear(embed_dim, embed_dim)def forward(self, x):batch_size, seq_len, embed_dim = x.shape# 投影q = self.q_proj(x)  # (batch, seq_len, embed_dim)k = self.k_proj(x)  # (batch, seq_len, head_dim)v = self.v_proj(x)  # (batch, seq_len, head_dim)# 重塑查询为多头q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)# (batch, num_heads, seq_len, head_dim)# 键和值保持单头,扩展到多头维度k = k.unsqueeze(1)  # (batch, 1, seq_len, head_dim)v = v.unsqueeze(1)  # (batch, 1, seq_len, head_dim)# 注意力计算scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale# (batch, num_heads, seq_len, seq_len)attn = F.softmax(scores, dim=-1)out = torch.matmul(attn, v)  # (batch, num_heads, seq_len, head_dim)# 合并多头out = out.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)out = self.out_proj(out)  # (batch, seq_len, embed_dim)return out# 示例用法
embed_dim = 64
num_heads = 8
model = MultiQueryAttention(embed_dim, num_heads)
x = torch.randn(2, 10, embed_dim)  # (batch, seq_len, embed_dim)
output = model(x)
print(output.shape)  # torch.Size([2, 10, 64])
http://www.xdnf.cn/news/541.html

相关文章:

  • 《擦除序列》线性时间做法题解
  • 利用 FastAPI 实现三种推送方式
  • 企业微信自建应用开发回调事件实现方案
  • AI文生图工具推荐
  • swift-12-Error处理、关联类型、assert、泛型_
  • Java ThreadPoolExecutor 深度解析:从原理到实战
  • 编译Spring源码时遇到的错误
  • HDMI如何进行插入检测
  • QML中的3D功能--纹理应用
  • Linux字符设备驱动
  • ZLMediaKit 和 SRS的区别,哪个更好用?
  • 在Qt和OSG中动态改变部分3D模型数据
  • 大模型API中转平台选择指南:如何找到优质稳定的服务
  • 压滤机与锡泥产生效率
  • OzGIS:地理信息分析与处理软件
  • C语言用if else求三个数最小值的一题多解
  • c++冒泡排序实现
  • Java Web 之 简介 100问
  • 大模型时代:机遇与风险并存的AI革命
  • Java Stream API 实践指南:从基础操作到高效用法
  • 【操作系统原理03】处理机调度与死锁
  • 运筹学之模拟退火
  • 生成模型StackGAN模型详解
  • 高效的项目构建:用 Makefile 自动化你的构建过程
  • Mybatis源码01-SpringBoot启动时mybatis加载过程
  • U-Boot 启动过程详解
  • 杂记-2025年4月19日
  • Linux压缩与解压命令完全指南:tar.gz、zip等格式详解
  • JAVA 继承
  • 【EDA软件】【设计约束和分析操作方法】