当前位置: 首页 > java >正文

Grouped Query Attention (GQA) PyTorch实现

个人在网上看到的实现好像都长得奇奇怪怪的,没有简洁的感觉,因此在这里给出一种易读的GQA实现方法:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass GroupedQueryAttention(nn.Module):def __init__(self, embed_dim, num_heads, num_groups):super().__init__()assert num_heads % num_groups == 0, "num_heads must be divisible by num_groups"self.num_heads = num_headsself.num_groups = num_groupsself.head_dim = embed_dim // num_headsself.group_dim = self.num_groups * self.head_dim  # Correct: num_groups * head_dimself.scale = self.head_dim ** -0.5# Projectionsself.q_proj = nn.Linear(embed_dim, embed_dim)  # Query: full embed_dim for num_headsself.k_proj = nn.Linear(embed_dim, self.group_dim)  # Key: group_dim for num_groupsself.v_proj = nn.Linear(embed_dim, self.group_dim)  # Value: group_dim for num_groupsself.out_proj = nn.Linear(embed_dim, embed_dim)def forward(self, x):batch_size, seq_len, embed_dim = x.shape# Project inputs to q, k, vq = self.q_proj(x)  # Shape: (batch_size, seq_len, embed_dim)k = self.k_proj(x)  # Shape: (batch_size, seq_len, group_dim)v = self.v_proj(x)  # Shape: (batch_size, seq_len, group_dim)# Reshape query for multi-head attentionq = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)# Shape: (batch_size, num_heads, seq_len, head_dim)# Reshape key and value for grouped attentionk = k.view(batch_size, seq_len, self.num_groups, self.head_dim).transpose(1, 2)# Shape: (batch_size, num_groups, seq_len, head_dim)v = v.view(batch_size, seq_len, self.num_groups, self.head_dim).transpose(1, 2)# Shape: (batch_size, num_groups, seq_len, head_dim)# Repeat k and v to match the number of query headsheads_per_group = self.num_heads // self.num_groupsk = k.repeat_interleave(heads_per_group, dim=1)# Shape: (batch_size, num_heads, seq_len, head_dim)v = v.repeat_interleave(heads_per_group, dim=1)# Shape: (batch_size, num_heads, seq_len, head_dim)# Compute attention scoresscores = torch.matmul(q, k.transpose(-2, -1)) * self.scale# Shape: (batch_size, num_heads, seq_len, seq_len)attn = F.softmax(scores, dim=-1)out = torch.matmul(attn, v)  # Shape: (batch_size, num_heads, seq_len, head_dim)# Reshape and project outputout = out.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)out = self.out_proj(out)  # Shape: (batch_size, seq_len, embed_dim)return out# Test the model
embed_dim = 64
num_heads = 8
num_groups = 4
model = GroupedQueryAttention(embed_dim, num_heads, num_groups)
x = torch.randn(2, 10, embed_dim)  # Input shape: (batch_size, seq_len, embed_dim)
output = model(x)
print(output.shape)  # Expected output: torch.Size([2, 10, 64])

为了读懂GQA,建议读者了解一下MQA的实现,这样顺着读下来会更顺手。

一旦读懂了MQA,GQA的实现思路几乎完全一样,只是多用了一个不太常用的函数tensor.repeat_interleave。关于这个函数,直接点击链接看笔者相关文章就行了,挺好懂的。

http://www.xdnf.cn/news/424.html

相关文章:

  • 单片机如何通过串口与上位机进行数据交换
  • RAG vs. CAG vs. Fine-Tuning:如何为你的大语言模型选择最合适的“脑力升级”?
  • 使用EXCEL绘制平滑曲线
  • 从代码学习深度学习 - 优化算法 PyTorch 版
  • Vue 3 中将 ref 创建的响应式对象数据转换为普通(非响应式)的数据
  • JAVA IO、BIO、NIO、AIO及零拷贝
  • Warcraft Logs [Classic] [WCL] Usage Wizard <HTOC>
  • FPGA系列之DDS信号发生器设计(DE2-115开发板)
  • 睡前小故事数据集分享
  • 腾讯wxg企业微信 后端开发一面
  • [Swift]Xcode模拟器无法请求http接口问题
  • 阿里云Clickhouse 冷热数据分层存储 实战记录
  • 【图片识别改名工具】图片文件区域OCR识别并自动重命名,批量识别指定区域根据指定识别文字批量改名,基于WPF和阿里云的技术方式实现
  • 二进制裁剪命令mips-linux-gnu-strip 命令的使用
  • NoSQl注入学习
  • 【Flutter动画深度解析】性能与美学的完美平衡之道
  • 多人五子棋联机对战平台 测试报告
  • 【绘制图像轮廓】图像处理(OpenCV) -part7
  • leetcode哈希表(六)-三数相加
  • P11299 [NOISG 2021 Finals] Fraud 题解
  • PHP异常处理__Exception类
  • 实验4基于神经网络的模式识别实验
  • opencv 图像的旋转
  • linux下C++性能调优常用的工具
  • 真实波幅策略思路
  • 数据驱动增长:大数据与营销自动化的结合之道
  • 芝法酱躺平攻略(21)——kafka安装和使用
  • Chromium 134 编译指南 macOS篇:编译优化技巧(六)
  • Warcraft Logs [Classic] [WCL] BOSS ID query
  • 分析虚幻引擎编辑器中使用 TAA 或 TSR 时角色眨眼导致的眼睛模糊问题