Attention层的FLOPs计算
前置知识
设矩阵 A 的维度为 m×n,矩阵 B 的维度为 n×p,则它们相乘后得到矩阵 C 的维度为 m×p。其中,C 中每个元素的计算需要进行 n 次乘法和 n−1 次加法。也就是说,总的浮点运算次数(FLOPs)约为 m × p × (2n) ≈ 2 × m × n × p。
Attention核心部分的计算
在一个 attention head 中,假设输入序列长度为 t,每个位置的表示维度(即 embedding 维度)为 d_head。在计算 self-attention 时,主要包含两个矩阵乘法操作:
1.查询矩阵与键矩阵的转置相乘(Q × K^T),计算量为 2 × t × t × d_head;
2.得分矩阵与值矩阵相乘,计算量同样为 2 × t × t × d_head。
则核心部分的总FLOPs为 4 × t × t × d_head
由于 Transformer 中通常使用多头注意力机制,设共有 n_head 个 head,并且每个 head 的维度为 d_head,那么有 d_model = n_head × d_head。于是所有 head 总共的 FLOPs 为:
4 × t × t × d_head × n_head = 4 × t × t × d_model
可见,在只考虑 attention 核心部分时,FLOPs 与 head 数量无关,仅与序列长度呈平方关系。
含有模型参数的矩阵乘法部分的FLOPs计算
除了注意力分数的计算外,Transformer 中还涉及多个由模型权重参与的线性映射,这些运算的 FLOPs 与序列长度呈线性关系。主要包括以下几个部分:
1.Q,K,V的映射:每个为输入矩阵(t × d_model)与权重矩阵(d_model × d_model)相乘,计算量为 2 × t × d_model × d_model(乘法与加法合计);三者合计为:
FLOPs ≈ 3 × 2 × t × d_model × d_model = 6 × t × d_model × d_model
2.concat以后的映射:拼接后的张量维度仍为 t × d_model,再乘以一个 d_model × d_model 的权重矩阵,FLOPs 为:
FLOPs ≈ 2 × t × d_model × d_model
综上,所有包含模型参数的线性变换的总 FLOPs 为:
FLOPs ≈ 8 × t × d_model × d_model
这部分 FLOPs 与序列长度 t 成线性关系。
总结
FLOPs的计算量可归结为2部分,其中一部分FLOPs与序列长度t呈平方关系,另一部分与序列长度 t 成线性关系,而且前者与n_head无关