当前位置: 首页 > news >正文

Attention层的FLOPs计算

前置知识

设矩阵 A 的维度为 m×n,矩阵 B 的维度为 n×p,则它们相乘后得到矩阵 C 的维度为 m×p。其中,C 中每个元素的计算需要进行 n 次乘法和 n−1 次加法。也就是说,总的浮点运算次数(FLOPs)约为 m × p × (2n) ≈ 2 × m × n × p。

Attention核心部分的计算

在一个 attention head 中,假设输入序列长度为 t,每个位置的表示维度(即 embedding 维度)为 d_head。在计算 self-attention 时,主要包含两个矩阵乘法操作:

1.查询矩阵与键矩阵的转置相乘(Q × K^T),计算量为 2 × t × t × d_head;
2.得分矩阵与值矩阵相乘,计算量同样为 2 × t × t × d_head。

则核心部分的总FLOPs为 4 × t × t × d_head

由于 Transformer 中通常使用多头注意力机制,设共有 n_head 个 head,并且每个 head 的维度为 d_head,那么有 d_model = n_head × d_head。于是所有 head 总共的 FLOPs 为:

4 × t × t × d_head × n_head = 4 × t × t × d_model

可见,在只考虑 attention 核心部分时,FLOPs 与 head 数量无关,仅与序列长度呈平方关系。

含有模型参数的矩阵乘法部分的FLOPs计算

除了注意力分数的计算外,Transformer 中还涉及多个由模型权重参与的线性映射,这些运算的 FLOPs 与序列长度呈线性关系。主要包括以下几个部分:
1.Q,K,V的映射:每个为输入矩阵(t × d_model)与权重矩阵(d_model × d_model)相乘,计算量为 2 × t × d_model × d_model(乘法与加法合计);三者合计为:
FLOPs ≈ 3 × 2 × t × d_model × d_model = 6 × t × d_model × d_model

2.concat以后的映射:拼接后的张量维度仍为 t × d_model,再乘以一个 d_model × d_model 的权重矩阵,FLOPs 为:
FLOPs ≈ 2 × t × d_model × d_model

综上,所有包含模型参数的线性变换的总 FLOPs 为:
FLOPs ≈ 8 × t × d_model × d_model

这部分 FLOPs 与序列长度 t 成线性关系。

总结

FLOPs的计算量可归结为2部分,其中一部分FLOPs与序列长度t呈平方关系,另一部分与序列长度 t 成线性关系,而且前者与n_head无关

http://www.xdnf.cn/news/234847.html

相关文章:

  • Linux 检查口令策略设置是否符合复杂度要求
  • 《FastAPI零基础入门与进阶实战》第10篇:Token验证
  • echarts
  • Python-pandas-操作csv文件(读取数据/写入数据)及csv语法详细分享
  • MiWi|Microchip开发的专有无线通信协议,适用于低功耗、短距离的无线个人局域网【无线通信小百科】
  • 简单表管理
  • SV 仿真的常识
  • 从有线到无线:冶炼工厂的高效转型
  • C盘哪些文件删除之后无影响,可以清理磁盘空间。
  • Web应用开发指南
  • PostgreSQL中的SSL(2)
  • Missashe考研日记-day31
  • UNet 改进(21):可变形卷积UNet架构
  • Java 实现 SM4 加密解密
  • SpringAI实现AI应用-搭建知识库
  • GPU集群搭建
  • BOTA新六维力传感器PixONE:用12维度力矩与运动感测,驱动人形机器人力控未来
  • Compose笔记(二十)--TextField
  • (31)VTK C++开发示例 ---绘制立方体
  • 第 12 届蓝桥杯 C++ 青少组中 / 高级组省赛 2021 年 4 月 24 日真题
  • C++好用的打印日志类
  • 2025.4.24 JavaScript 基础学习笔记
  • [特殊字符] 蓝桥杯省赛全解析:含金量、获奖难度、参赛意义与发展价值全面剖析
  • 精华贴分享|【零敲碎打12】类筹码数据构建-散户行为倾向
  • react初学踏坑记录-if(number)到底过滤了什么
  • leetcode0075. 颜色分类-medium
  • 数学:拉马努金如何想出计算圆周率的公式?
  • 大连理工大学选修课——机器学习笔记(3):KNN原理及应用
  • 【中间件】bthread效率为什么高?
  • 12.Three.js 中的 DirectionalLight(平行光)详解指南