大模型中的KV Cache
1. KV Cache的定义与核心原理
KV Cache(Key-Value Cache)是一种在Transformer架构的大模型推理阶段使用的优化技术,通过缓存自注意力机制中的键(Key)和值(Value)矩阵,避免重复计算,从而显著提升推理效率。
原理:
-
自注意力机制:在Transformer中,注意力计算基于公式:
Attention ( Q , K , V ) = softmax ( Q K ⊤ d k ) V = ∑ i = 1 n w i v i (加权求和形式) \begin{split} \text{Attention}(Q, K, V) &= \text{softmax}\left( \frac{QK^\top}{\sqrt{d_k}} \right) V \\ &= \sum_{i=1}^n w_i v_i \quad \text{(加权求和形式)} \end{split} Attention(Q,K,V)=softmax(dkQK⊤)V=i=1∑nwivi(加权求和形式)
其中,Q(Query)、K(Key)、V(Value)由输入序列线性变换得到。 -
缓存机制:在生成式任务(如文本生成)中,模型以自回归方式逐个生成token。首次推理时,计算所有输入token的K和V并缓存;后续生成时,仅需为新token计算Q,并从缓存中读取历史K和V进行注意力计算。
-
复杂度优化:传统方法的计算复杂度为O(n²),而KV Cache将后续生成的复杂度降为O(n),避免重复计算历史token的K和V。
2. KV Cache的核心作用
-
加速推理:通过复用缓存的K和V,减少矩阵计算量,提升生成速度。例如,某聊天机器人应用响应时间从0.5秒缩短至0.2秒。
-
降低资源消耗:显存占用减少约30%-50%(例如移动端模型从1GB降至0.6GB),支持在资源受限设备上部署大模型。
-
支持长文本生成:缓存机制使推理耗时不再随文本长度线性增长,可稳定处理长序列(如1024 token以上)。
-
保持模型性能:仅优化计算流程,不影响输出质量。
3. 技术实现与优化策略
实现方式:
-
数据结构
- KV Cache以张量形式存储,Key Cache和Value Cache的形状分别为
(batch_size, num_heads, seq_len, k_dim)
和(batch_size, num_heads, seq_len, v_dim)
。
- KV Cache以张量形式存储,Key Cache和Value Cache的形状分别为
-
两阶段推理:
- 初始化阶段:计算初始输入的所有K和V,存入缓存。
- 迭代阶段:仅计算新token的Q,结合缓存中的K和V生成输出,并更新缓存。
• 代码示例(Hugging Face Transformers):设置model.generate(use_cache=True)
即可启用KV Cache。
优化策略:
-
稀疏化(Sparse):仅缓存部分重要K和V,减少显存占用。
-
量化(Quantization):将K和V矩阵从FP32转为INT8/INT4,降低存储需求。
共享机制(MQA/GQA):
-
Multi-Query Attention (MQA):所有注意力头共享同一组K和V,显存占用降低至1/头数。
-
Grouped-Query Attention (GQA):将头分组,组内共享K和V,平衡性能和显存。
4. 挑战与局限性
-
显存压力:随着序列长度增加,缓存占用显存线性增长(如1024 token占用约1GB显存),可能引发OOM(内存溢出)。
-
冷启动问题:首次推理仍需完整计算K和V,无法完全避免初始延迟。
5、python实现
import torch
import torch.nn as nn# 超参数
d_model = 4
n_heads = 1
seq_len = 3
batch_size = 3# 初始化参数(兼容多头形式)
Wq = nn.Linear(d_model, d_model, bias=False)
Wk = nn.Linear(d_model, d_model, bias=False)
Wv = nn.Linear(d_model, d_model, bias=False)# 生成模拟输入(整个序列一次性输入)
input_sequence = torch.randn(batch_size, seq_len, d_model) # [B, L, D]# 初始化 KV 缓存(兼容多头格式)
kv_cache = {"keys": torch.empty(batch_size, 0, n_heads, d_model // n_heads), # [B, T, H, D/H]"values": torch.empty(batch_size, 0, n_heads, d_model // n_heads)
}# 因果掩码预先生成(覆盖最大序列长度)
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool() # [L, L]'''
本循环是将整句话中的token一个一个输入,并更新KV缓存;
所以无需显示的因果掩码,因为因果掩码只用于计算注意力权重时,而计算注意力权重时,KV缓存中的key和value已经包含了因果掩码的信息。'''for step in range(seq_len):# 1. 获取当前时间步的输入(整个批次)current_token = input_sequence[:, step, :] # [B, 1, D]# 2. 计算当前时间步的 Q/K/V(保持三维结构)q = Wq(current_token) # [B, 1, D]k = Wk(current_token) # [B, 1, D]v = Wv(current_token) # [B, 1, D]# 3. 调整维度以兼容多头格式(关键修改点)def reshape_for_multihead(x):return x.view(batch_size, 1, n_heads, d_model // n_heads).transpose(1, 2) # [B, H, 1, D/H]# 4. 更新 KV 缓存(增加时间步维度)kv_cache["keys"] = torch.cat([kv_cache["keys"], reshape_for_multihead(k).transpose(1, 2) # [B, T+1, H, D/H]], dim=1)kv_cache["values"] = torch.cat([kv_cache["values"],reshape_for_multihead(v).transpose(1, 2) # [B, T+1, H, D/H]], dim=1)# 5. 多头注意力计算(支持批量处理)q_multi = reshape_for_multihead(q) # [B, H, 1, D/H]k_multi = kv_cache["keys"].transpose(1, 2) # [B, H, T+1, D/H]print("q_multi shape:", q_multi.shape)print("k_multi shape:", k_multi.shape)# 6. 计算注意力分数(带因果掩码)attn_scores = torch.matmul(q_multi, k_multi.transpose(-2, -1)) / (d_model ** 0.5)print("attn_scores shape:", attn_scores.shape)# attn_scores = attn_scores.masked_fill(causal_mask[:step+1, :step+1], float('-inf'))# print("attn_scores shape:", attn_scores.shape)# 7. 注意力权重计算attn_weights = torch.softmax(attn_scores, dim=-1) # [B, H, 1, T+1]# 8. 加权求和output = torch.matmul(attn_weights, kv_cache["values"].transpose(1, 2)) # [B, H, 1, D/H]# 9. 合并多头输出output = output.contiguous().view(batch_size, 1, d_model) # [B, 1, D]print(f"Step {step} 输出:", output.shape)