当前位置: 首页 > news >正文

GQA(Grouped Query Attention):分组注意力机制的原理与实践《三》

GQA 是一种在多头注意力中共享 Key/Value,但拥有独立 Query 的结构,用于提升推理效率、减少冗余计算。

✅ GQA vs 多头注意力 (MHA)

•	MHA:每个 head 都有独立的 Q/K/V
•	GQA:每个 head 有独立 Q,但共享组内 K/V

🚀 GQA 简易 PyTorch 实现

import torch
import torch.nn as nn
import torch.nn.functional as Fclass GQAAttention(nn.Module):def __init__(self, hidden_size, num_heads, num_kv_groups=1):super().__init__()assert hidden_size % num_heads == 0self.hidden_size = hidden_sizeself.num_heads = num_headsself.head_dim = hidden_size // num_headsself.num_kv_groups = num_kv_groupsassert num_heads % num_kv_groups == 0# 每个 head 的 Q 独立self.q_proj = nn.Linear(hidden_size, hidden_size)# K 和 V 是共享的(Group-wise),因此维度为 num_kv_groups * head_dimself.k_proj = nn.Linear(hidden_size, self.head_dim * num_kv_groups)self.v_proj = nn.Linear(hidden_size, self.head_dim * num_kv_groups)self.out_proj = nn.Linear(hidden_size, hidden_size)def forward(self, x):B, T, _ = x.size()# Q: [B, T, H * D] → [B, H, T, D]q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)# K/V: [B, T, G * D] → [B, G, T, D]k = self.k_proj(x).view(B, T, self.num_kv_groups, self.head_dim).transpose(1, 2)v = self.v_proj(x).view(B, T, self.num_kv_groups, self.head_dim).transpose(1, 2)# 将 KV 扩展到每个 head(head 与 group 对应)heads_per_group = self.num_heads // self.num_kv_groupsk = k.repeat_interleave(heads_per_group, dim=1)v = v.repeat_interleave(heads_per_group, dim=1)# Attention: [B, H, T, D] x [B, H, D, T] → [B, H, T, T]attn_weights = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)attn_probs = F.softmax(attn_weights, dim=-1)attn_output = torch.matmul(attn_probs, v)  # [B, H, T, D]attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, self.hidden_size)return self.out_proj(attn_output)

🧠 参数解释

参数名 含义
hidden_size 模型总隐藏维度
num_heads Query 的数量
num_kv_groups K/V 分组数量(小于 num_heads)
heads_per_group 每组多少个 head 共享一个 KV

📌 举例:设置说明

GQAAttention(hidden_size=768, num_heads=12, num_kv_groups=4)

含义为:
• 有 12 个 Q-head(每个独立)
• 只有 4 个 K/V group(被共享)
• 每 3 个 Q-head 共享 1 个 KV group

✅ GQAAttention 测试函数(PyTorch)

def test_gqa():import torch# 参数设置batch_size = 2seq_len = 10hidden_size = 768num_heads = 12num_kv_groups = 4# 构造 GQA 模块gqa = GQAAttention(hidden_size=hidden_size, num_heads=num_heads, num_kv_groups=num_kv_groups)# 随机构造输入:[B, T, H]dummy_input = torch.randn(batch_size, seq_len, hidden_size)# 执行前向传播output = gqa(dummy_input)# 打印输出维度print("Input shape:", dummy_input.shape)print("Output shape:", output.shape)# 断言输出维度匹配输入assert output.shape == (batch_size, seq_len, hidden_size), "Output shape mismatch!"print("✅ GQA forward pass test passed.")if __name__ == "__main__":test_gqa()

输出

Input shape: torch.Size([2, 10, 768])
Output shape: torch.Size([2, 10, 768])
✅ GQA forward pass test passed.
http://www.xdnf.cn/news/869185.html

相关文章:

  • AIGC1——AIGC技术原理与模型演进:从GAN到多模态融合的突破
  • 基础电学笔记
  • 6.4 C++作业
  • DeepSeek 赋能医疗新生态:远程会诊智能化转型之路
  • Vue.js教学第十九章:Vue 工具与调试,Vue DevTools 的使用与 VS Code 插件辅助开发
  • Leetcode日记
  • PyTorch实战(8)——深度卷积生成对抗网络
  • MySQL 9.0 相较于 MySQL 8.0 引入了多项重要改进和新特性
  • 【DeepSeek】【Dify】:用 Dify 对话流+标题关键词注入,让 RAG 准确率飞跃
  • 数学复习笔记 25
  • 2025 年最新 conda 和 pip 国内镜像源
  • 2025 Vscode插件离线下载方式
  • 通过paramiko 远程在windows机器上启动conda环境并执行python脚本
  • kubernetes》》k8s》》kubectl proxy 命令后面加一个
  • Zookeeper 集群部署与故障转移
  • vue-16(Vuex 中的模块)
  • 智能推荐系统:协同过滤与深度学习结合
  • 从上下文学习和微调看语言模型的泛化:一项对照研究
  • 网络攻防技术十四:入侵检测与网络欺骗
  • `<CLS>` 向量是 `logits` 计算的“原材料”,`logits` 是基于 `<CLS>` 向量的下游预测结果
  • pikachu靶场通关笔记13 XSS关卡09-XSS之href输出
  • Spring 中注入 Bean 有几种方式?
  • 身体节奏失调现象探秘
  • Windows GDI 对象泄漏排查实战
  • Bootstrap 5学习教程,从入门到精通,Bootstrap 5 容器(Container)语法知识点及案例代码详解(4)
  • RAG-Gym:一个用于优化带过程监督的代理型RAG的统一框架
  • macOS 连接 Docker 运行 postgres
  • HarmonyOS 实战:给笔记应用加防截图水印
  • 【Kdump专题】kexec加载捕获内核和 makedumpfile保存Vmcore
  • GPUCUDA 发展编年史:从 3D 渲染到 AI 大模型时代(上)