当前位置: 首页 > news >正文

大模型-attention汇总解析之-MHA

一、MHA(Multi-Head Attention)

1.1 MHA 原理

MHA(Multi-Head Attention)称为多头注意力,开山之作所提出的一种 Attention 计算形式,它是当前主流 LLM 的基础工作。在数学原理上,多头注意力 MHA 等价于多个独立的单头注意力的拼接, MHA 可以形式地记为:

公式展开下如下:

Attention 的计算公式如下: 

Attention 计算模型结构和MHA的模型结构示意图:

在实践中,为了减少计算复杂度和内存占用,通常会设置 ,其中 d 是模型的维度,h 是缩放因子(也称为头数,即多头注意力中的头的数量)。对于 LLaMA2-7b 模型:模型维度 d = 4096,多头数 h = 32, 因此,d_k = d_v = 128(即 4096 / 32.

这里我们只考虑了主流的自回归 LLM 所用的 Causal Attention,因此在 token by token 递归生成时,新预测出来的第 i+1个 token,并不会影响到已经算好的 前面的i个K, V的值,因此这部分K, V结果我们可以缓存下来供后续生成调用,避免不必要的重复计算,这就是所谓的 KV Cache。下面是kv cache的示意图。

 

 多头MHA的Kv cache 的简单实现:

import torch
import torch.nn as nn
import mathclass CachedAttention(nn.Module):def __init__(self, d_model, num_heads):super().__init__()self.d_model = d_modelself.num_heads = num_headsself.head_dim = d_model // num_heads# 定义线性变换层,将输入映射到Query、Key和Value空间self.q_proj = nn.Linear(d_model, d_model)self.k_proj = nn.Linear(d_model, d_model)self.v_proj = nn.Linear(d_model, d_model)# 定义输出线性变换层,将注意力计算结果映射回原维度self.out_proj = nn.Linear(d_model, d_model)def forward(self, x, kv_cache=None):b, t, c = x.shape# 将输入x通过线性变换得到Query,并调整形状和维度q = self.q_proj(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)# 将输入x通过线性变换得到Key,并调整形状和维度k = self.k_proj(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)# 将输入x通过线性变换得到Value,并调整形状和维度v = self.v_proj(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)if kv_cache is not None:cached_k, cached_v = kv_cache# 将缓存中的Key和当前计算的Key拼接起来k = torch.cat((cached_k, k), dim=2)# 将缓存中的Value和当前计算的Value拼接起来v = torch.cat((cached_v, v), dim=2)# 计算注意力分数,这里除以根号下head_dim是为了缩放attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))# 对注意力分数进行softmax归一化attn = attn.softmax(dim=-1)# 根据注意力分数对Value进行加权求和y = (attn @ v).transpose(1, 2).contiguous().view(b, t, c)# 通过输出线性变换层得到最终输出y = self.out_proj(y)return y, (k, v)

 

1.2 存在的问题

看下attention计算的公式:

 

从上面的可以知道:

  1. attention2的计算和Q2, K1, K2, V1, V2有关系。

  2. 如果我们把之前已经计算好的K1, V1 保存起来,那么这一步的计算量就节省了,从而可以使用空间换时间,加快计算速度。

  3. 人们总是不断的追求极致, 那么能不能再进一步的节省空间,减少KV cache的同时,保证计算的效果还能达到要求呢。

所以后续就出现了一系列的attention的优化方法。这里先上一张简洁明了的示意图。后续再聊

http://www.xdnf.cn/news/736543.html

相关文章:

  • 电工基础【2】自锁、互锁、正反转电路
  • leetcode动态规划—完全背包系列
  • 【论文阅读】DanceGRPO: Unleashing GRPO on Visual Generation
  • ✨1.1.1 按位与运算替代求余运算优化场景
  • 在 Android 上备份短信:保护您的对话
  • N2语法 状態
  • win7怎么关闭开机自启动程序?
  • 深入理解 MySQL 隔离级别:理论与实战
  • 百年陈韵,三代匠心——陈汇堂新会陈皮的传承与新生
  • linux 中路由解决方案1
  • Cross-Encoder(交叉编码器)和 Bi-Encoder(双编码器)
  • 页面表格、模型、脚本这三者之间的数据是如何传输的?尤其是模型(Model)到底是怎样的运作原理与数据流转?
  • 判断质数的基础方法
  • Maven高级篇
  • Selenium操作指南(全)
  • 本地部署AI工作流
  • vivado仿真文件的相对地址设置方法
  • LangChain第二页_【教程】翻译完了
  • 前端面试之Proxy与Reflect
  • tryhackme——Windows Internals
  • PyQt6基础_QtCharts绘制横向柱状图
  • 代码随想录算法训练营第60期第五十二天打卡
  • 六步完成软件验收:从计划到终验的全面指南(一)
  • 【瑶池数据库训练营及解决方案本周精选(探索PolarDB,参与RDS迁移、连接训练营)】
  • mobile app 工具简要对比
  • 秒出PPT正式改名秒出AI,开启AI赋能新体验!
  • 数字人革新教育:开启智慧教学新时代
  • 力扣面试150题--二叉树的层平均值
  • 探讨分贝计在医疗环境中的具体应用及其重要性
  • 基于VU37P的高性能采集板卡