当前位置: 首页 > news >正文

LLM笔记(九)KV缓存(2)

文章目录

    • 1. 背景与动机
    • 2. 不使用 KV Cache 的情形
      • 2.1 矩阵形式展开
      • 2.2 计算复杂度
    • 3. 使用 KV Cache 的优化
      • 3.1 核心思想
      • 3.2 矩阵形式展开
      • 3.3 计算复杂度对比
    • 4. 总结
    • 5. GPT-2 中 KV 缓存的实现分析
      • 5.1 缓存的数据结构与类型
      • 5.2 在注意力机制 (`GPT2Attention`) 中使用缓存
      • 5.3 缓存的更新机制 (`Cache.update`)
      • 5.4 在模型整体 (`GPT2Model`) 的 `forward` 方法中处理
      • 5.5 因果掩码 (Causal Mask) 与 KV 缓存的配合
      • 5.6 支持多种高效的注意力实现
      • 5.7 KV 缓存的完整工作流程 (自回归生成)
        • 5.7.1 初始步骤 (t=0):
        • 5.7.2 后续步骤 (t > 0):
      • KV 缓存的显著优势

  1. 看图学kv 很形象清楚
  2. gpt2源码
  3. 分析transformer模型的参数量、计算量、中间激活、KV cache量化分析了缓存
  4. kv解读

1. 背景与动机

在自回归生成(autoregressive generation)任务中,Transformer 解码器需要在每一步中根据前面已生成的所有 token 重新计算注意力(Attention),这会产生大量重复计算。引入 KV Cache(Key–Value Cache)后,能够将已计算的键值对缓存下来,仅对新增的 Query 进行点乘与加权,从而大幅降低时间与算力开销。

2. 不使用 KV Cache 的情形

2.1 矩阵形式展开

  • 第 1 步(生成第一个 token)

    Q 1 , K 1 , V 1 ∈ R 1 × d Q_1, K_1, V_1 \in \mathbb{R}^{1\times d} Q1,K1,V1R1×d

    A t t e n t i o n 1 = s o f t m a x ( Q 1 K 1 ⊤ d ) , V 1 Attention_1 = \mathrm{softmax}\Bigl(\frac{Q_1 K_1^\top}{\sqrt d}\Bigr),V_1 Attention1=softmax(d Q1K1),V1

  • 第 2 步(生成第二个 token)
    构造全序列的矩阵:
    image.png

    需重算完整注意力矩阵:

    A t t e n t i o n 1 : 2 = s o f t m a x ( Q 1 : 2 K 1 : 2 ⊤ d ) , V 1 : 2 Attention_{1:2} = \mathrm{softmax}\Bigl(\frac{Q_{1:2}K_{1:2}^\top}{\sqrt d}\Bigr),V_{1:2} Attention1:2=softmax(d Q1:2K1:2),V1:2

    计算出一个 2 × 2 2\times 2 2×2 矩阵,但我们只取最后一行作为输出。

  • 第 n 步

    Q 1 : n , K 1 : n , V 1 : n ∈ R n × d , A t t e n t i o n 1 : n = s o f t m a x ( Q 1 : n K 1 : n ⊤ d ) , V 1 : n Q_{1:n},K_{1:n},V_{1:n}\in\mathbb{R}^{n\times d},\quad Attention_{1:n} = \mathrm{softmax}\Bigl(\tfrac{Q_{1:n}K_{1:n}^\top}{\sqrt d}\Bigr),V_{1:n} Q1:n,K1:n,V1:nRn×d,Attention1:n=softmax(d Q1:nK1:n),V1:n

    每步均重新构建并计算 n × n n\times n n×n 注意力矩阵。

2.2 计算复杂度

  • 注意力矩阵构建 O ( n 2 ⋅ d ) O(n^2\cdot d) O(n2d)

  • 整体推理阶段:若生成总长度为 N N N,则总复杂度近似为

    ∑ n = 1 N O ( n 2 d ) ; = ; O ( N 3 d ) \sum_{n=1}^N O(n^2 d);=;O(N^3 d) n=1NO(n2d);=;O(N3d),

    由于每步都做重复计算,效率极低。

3. 使用 KV Cache 的优化

3.1 核心思想

  • 缓存已计算的 K, V:对于前序列位置的键值对,只需计算一次并存储。

  • 仅对新增 Query 进行点乘:第 n n n 步仅需计算 Q n Q_n Qn 与所有缓存 K 的点乘,得到长度为 n n n 的注意力权重,再加权叠加对应的 V。

3.2 矩阵形式展开

  • 第 1 步:如前,无缓存,计算
    A t t e n t i o n 1 = s o f t m a x ( Q 1 K 1 ⊤ / d ) , V 1 Attention_1 = \mathrm{softmax}(Q_1K_1^\top/\sqrt d),V_1 Attention1=softmax(Q1K1/d ),V1.

  • 第 2 步

    • 新增 Q 2 ∈ R 1 × d Q_2\in\mathbb{R}^{1\times d} Q2R1×d

    • 缓存矩阵已扩展为

      image.png

    • 只做一次 1 × 2 1\times 2 1×2 点乘:

      A t t e n t i o n 2 = s o f t m a x ( Q 2 K c a c h e ⊤ d ) , V c a c h e Attention_2 = \mathrm{softmax}\Bigl(\tfrac{Q_2 K_{\mathrm{cache}}^\top}{\sqrt d}\Bigr),V_{\mathrm{cache}} Attention2=softmax(d Q2Kcache),Vcache,

      输出即为所需的 1 × d 1\times d 1×d 向量。

  • 第 n 步

    K c a c h e ∈ R n × d , V c a c h e ∈ R n × d , A t t e n t i o n n = s o f t m a x ( Q n K c a c h e ⊤ d ) , V c a c h e K_{\mathrm{cache}}\in\mathbb{R}^{n\times d},\quad V_{\mathrm{cache}}\in\mathbb{R}^{n\times d},\quad Attention_n = \mathrm{softmax}\Bigl(\tfrac{Q_n K_{\mathrm{cache}}^\top}{\sqrt d}\Bigr),V_{\mathrm{cache}} KcacheRn×d,VcacheRn×d,Attentionn=softmax(d QnKcache),Vcache.

3.3 计算复杂度对比

模式每步复杂度总体复杂度(生成长度 N N N
无 Cache O ( n 2 d ) O(n^2 d) O(n2d) O ( N 3 d ) O(N^3 d) O(N3d)
有 KV Cache O ( n d ) O(n d) O(nd) ∑ n = 1 N O ( n d ) = O ( N 2 d ) \displaystyle\sum_{n=1}^N O(n d)=O(N^2 d) n=1NO(nd)=O(N2d)
  • 加速比:从二次方级别 O ( n 2 ) O(n^2) O(n2) 降到线性级别 O ( n ) O(n) O(n),对长序列提升显著。

4. 总结

  1. 多头注意力(Multi-Head)
    每个 head 独立缓存自己的 K, V 矩阵,计算时分别点乘再拼接。总体计算与存储线性可扩展。

  2. 缓存管理

    • 内存占用:缓存矩阵大小随生成长度增长,应考虑清理过旧不再需要的序列(如 sliding window)。

    • Batch 推理:对多条序列并行生成时,可为每条序列维护独立缓存,或统一按最大长度对齐。

  3. 硬件优化

    • 内存带宽:KV Cache 减少重复内存载入,对带宽友好;

    • 并行度:线性点乘更易与矩阵乘加(GEMM)指令级并行融合。

  4. 实践中常见问题

    • Cache 不命中:若使用 prefix-tuning 等技术动态修改 key/value,需谨慎处理缓存一致性。
    • 数值稳定性:长序列高维 softmax 易出现梯度消失/爆炸,可结合温度系数或分段归一化。

5. GPT-2 中 KV 缓存的实现分析

GPT-2(以及许多其他基于 Transformer 的自回归模型)在生成文本时,为了提高效率,会使用一种称为 KV 缓存 (Key-Value Cache) 的机制。其核心思想是:在生成第 t 个 token 时,计算注意力所需的键 (Key) 和值 (Value) 向量可以部分来自于已经生成的 t-1 个 token。通过缓存这些历史的 K 和 V 向量,可以避免在每一步生成时都对整个已生成序列重新进行昂贵的 K 和 V 计算。

5.1 缓存的数据结构与类型

Hugging Face Transformers 库为 GPT-2 提供了灵活的缓存管理机制,主要通过 Cache 基类及其子类实现。

  • Cache (基类): 定义了缓存对象的基本接口,例如 update (更新缓存) 和 get_seq_length (获取当前缓存的序列长度) 等方法。
  • DynamicCache:
    • 这是自回归生成时最常用的缓存类型。
    • 它允许缓存的序列长度动态增长。当生成新的 token 时,新计算出的 K 和 V 向量会被追加到已有的缓存后面。
    • 不需要预先分配固定大小的内存,更加灵活,但可能在内存管理上有一些开销。
  • StaticCache:
    • 在创建时就需要预先分配固定大小的内存空间来存储 K 和 V 向量。
    • 适用于已知最大生成长度或需要更可控内存占用的场景。
    • 如果生成的序列长度超过了预分配的大小,可能会出错或需要特殊处理。
  • EncoderDecoderCache:
    • 主要用于 Encoder-Decoder 架构的模型 (如 T5, BART)。
    • 它内部会分别管理编码器-解码器注意力(交叉注意力)的 KV 缓存和解码器自注意力的 KV 缓存。
    • GPT-2 是一个仅解码器 (Decoder-only) 模型,所以主要关注自注意力的缓存。
# 相关类的导入,展示了缓存工具的多样性
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache

5.2 在注意力机制 (GPT2Attention) 中使用缓存

GPT2Attention 类的 forward 方法是 KV 缓存机制的核心应用点。

class GPT2Attention(nn.Module):  ...  def forward(  self,  hidden_states: Optional[Tuple[torch.FloatTensor]],  layer_past: Optional[Tuple[torch.Tensor]] = None, # 旧版本的缓存参数名  past_key_value: Optional[Cache] = None,           # 新版本的缓存对象  attention_mask: Optional[torch.FloatTensor] = None,  head_mask: Optional[torch.FloatTensor] = None,  use_cache: Optional[bool] = False,  output_attentions: Optional[bool] = False,  cache_position: Optional[torch.LongTensor] = None, # 指示新token在缓存中的位置  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:  # 1. 计算当前输入 hidden_states 的 Q, K, V        # self.c_attn 是一个线性层,通常一次性计算出 Q, K, V 然后分割  query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)  # 2. 将 Q, K, V 重塑为多头形式 (batch_size, num_heads, seq_len, head_dim)        query = self._split_heads(query, self.num_heads, self.head_dim)  key = self._split_heads(key, self.num_heads, self.head_dim)  value = self._split_heads(value, self.num_heads, self.head_dim)  # 3. KV 缓存处理  if past_key_value is not None:  # 如果是 EncoderDecoderCache,根据是否交叉注意力选择正确的缓存  if isinstance(past_key_value, EncoderDecoderCache):  # ... (GPT-2 不直接使用此逻辑,但展示了其通用性)  pass  # 使用 cache_position 来更新缓存中的特定位置  cache_kwargs = {"cache_position": cache_position}  # 调用缓存对象的 update 方法  # key 和 value 是当前新计算的 K, V            # self.layer_idx 标识当前是哪一层的缓存  key, value = past_key_value.update(key, value, self.layer_idx, cache_kwargs)  # 此时的 key 和 value 包含了历史信息和当前新计算的信息  # 4. 计算注意力权重 (Q @ K^T)        # ...        attn_weights = torch.matmul(query, key.transpose(-1, -2))  # ... 应用注意力掩码 (causal mask, padding mask) ...  # 5. 计算注意力输出 (attn_weights @ V)        attn_output = torch.matmul(attn_weights, value)  # ... 合并多头,返回结果 ...  if use_cache:  # 如果使用缓存,则 present_key_value 就是更新后的 past_key_value            present_key_value = past_key_value  else:  present_key_value = None  return attn_output, present_key_value # 返回注意力的输出和更新后的缓存

关键点解释:

  • past_key_value (或 layer_past): 这是从上一个时间步或上一个调用传递过来的缓存对象。它包含了到目前为止所有先前 token 的 K 和 V 向量。
  • cache_position: 这是一个非常重要的参数,尤其是在使用了诸如 Flash Attention 2 等更高级的注意力实现时。它告诉缓存 update 方法以及注意力计算函数,新的 K 和 V 向量应该被放置在缓存张量的哪个位置。这对于正确地处理填充(padding)和动态序列长度至关重要。例如,如果当前输入的是第 t 个 token(从0开始计数),cache_position 可能就是 t
  • self.layer_idx: Transformer 模型通常由多个相同的注意力层堆叠而成。每一层都有自己独立的 KV 缓存。layer_idx 用于标识当前正在处理的是哪一层的缓存,确保数据被正确地存取。
  • use_cache: 控制是否使用和返回缓存。在训练时通常为 False(除非进行特定类型的训练,如 teacher forcing 的逐token训练),在推理(生成)时为 True

5.3 缓存的更新机制 (Cache.update)

Cache 对象的 update 方法是实现缓存的核心。虽然具体的实现会因 DynamicCacheStaticCache 而异,但其基本逻辑是:

class DynamicCache(Cache):  def __init__(self):  self.key_cache: List[torch.Tensor] = [] # 每层一个 tensor        self.value_cache: List[torch.Tensor] = [] # 每层一个 tensor        self.seen_tokens = 0 # 已缓存的token数量  def update(  self,  key_states: torch.Tensor,    # 新计算的 key        value_states: torch.Tensor,  # 新计算的 value        layer_idx: int,              # 当前层索引  cache_kwargs: Optional[Dict[str, Any]] = None,  ) -> Tuple[torch.Tensor, torch.Tensor]:  # 获取 cache_position        cache_position = cache_kwargs.get("cache_position")  # 如果是第一次更新这一层 (或缓存为空)  if layer_idx >= len(self.key_cache):  # 初始化该层的缓存张量  # ... 根据 key_states 和 value_states 的形状以及预估的最大长度(或动态调整)  self.key_cache.append(torch.zeros_like(key_states_preallocated))  self.value_cache.append(torch.zeros_like(value_states_preallocated))  # 将新的 key_states 和 value_states 写入到缓存的指定位置  # 对于 DynamicCache,通常是直接拼接或在预分配空间中按位置写入  if cache_position is not None:  # 使用 cache_position 精确地更新缓存的特定部分  # 例如: self.key_cache[layer_idx][:, :, cache_position, :] = key_states            #       self.value_cache[layer_idx][:, :, cache_position, :] = value_states            # 这里的维度可能需要根据实际实现调整  # 重要的是理解 cache_position 的作用  # 例如,如果 key_states 的形状是 (batch, num_heads, new_seq_len, head_dim)            # cache_position 的形状可能是 (batch, new_seq_len) 或广播的 (new_seq_len)            # 需要将 key_states 放置到 self.key_cache[layer_idx] 的正确"槽位"  # 对于自回归,通常 new_seq_len = 1            self.key_cache[layer_idx].index_copy_(dim=2, index=cache_position, source=key_states)  self.value_cache[layer_idx].index_copy_(dim=2, index=cache_position, source=value_states)  # 更新已见过的token数量  self.seen_tokens = cache_position[-1] + 1 # 取最后一个新token的位置加1  else: # 旧的、不使用 cache_position 的逻辑(通常是简单拼接)  self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)  self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)  self.seen_tokens += key_states.shape[2]  # 返回包含所有历史信息(包括刚更新的)的 K 和 V 状态  return self.key_cache[layer_idx], self.value_cache[layer_idx]

update 方法的关键职责:

  1. 接收当前新计算的 key_statesvalue_states
  2. 根据 layer_idx 找到对应层的缓存。
  3. (可选,但推荐)使用 cache_position 将新的 K, V 向量精确地放置到缓存张量的正确位置。这对于处理批处理中不同样本有不同历史长度的情况(例如,在束搜索beam search后或 speculative decoding 后),或者在有填充 token 时非常重要。
  4. 返回完整的、包含所有历史信息和当前新信息的 K, V 向量,供后续的注意力计算使用。
  5. 更新内部状态,如已缓存的 token 数量 (seen_tokens)。

5.4 在模型整体 (GPT2Model) 的 forward 方法中处理

GPT2Modelforward 方法负责协调整个模型的流程,包括缓存的初始化、传递和 cache_position 的计算。

class GPT2Model(GPT2PreTrainedModel):  def forward(  self,  input_ids: Optional[torch.LongTensor] = None,  past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, # 旧版缓存元组  attention_mask: Optional[torch.FloatTensor] = None,  # ...  use_cache: Optional[bool] = None,  output_attentions: Optional[bool] = None,  output_hidden_states: Optional[bool] = None,  return_dict: Optional[bool] = None,  cache_position: Optional[torch.LongTensor] = None,  ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:  # ... (处理输入ID和嵌入) ...  inputs_embeds = self.wte(input_ids) # 词嵌入  position_embeds = self.wpe(position_ids) # 位置嵌入  hidden_states = inputs_embeds + position_embeds  # 1. 缓存初始化和类型转换  if use_cache:  if past_key_values is None: # 如果是第一次调用或没有提供缓存  # 根据配置决定使用哪种缓存,通常是 DynamicCache                # 例如:self.config.cache_implementation == "dynamic"  past_key_values = DynamicCache()  elif not isinstance(past_key_values, Cache):  # 为了兼容旧的元组形式的缓存,将其转换为新的 Cache 对象  past_key_values = DynamicCache.from_legacy_cache(past_key_values)  # else: past_key_values 保持为 None  # 2. 计算 cache_position        if cache_position is None: # 如果外部没有提供 cache_position            # 获取当前缓存中已有的 token 数量  past_seen_tokens = past_key_values.get_seq_length(self.config.num_hidden_layers) if past_key_values is not None else 0  # 当前输入序列的长度  current_seq_length = inputs_embeds.shape[1]  # cache_position 从 past_seen_tokens 开始,长度为 current_seq_length            cache_position = torch.arange(  past_seen_tokens, past_seen_tokens + current_seq_length, device=inputs_embeds.device  )  # else: 使用外部传入的 cache_position  # ... (准备注意力掩码,考虑因果关系和缓存长度) ...  # 3. 逐层传递和更新缓存  all_hidden_states = () if output_hidden_states else None  all_self_attentions = () if output_attentions else None  # next_decoder_cache 用于收集下一轮的缓存 (如果 use_cache 为 True)        # 在新的 Cache 对象设计中,past_key_values 本身会被原地更新或返回更新后的版本  # 因此,这个 next_decoder_cache 可能不再是必需的,或者其角色由 past_key_values 自身承担  for i, block in enumerate(self.h): # self.h 是 GPT2Block 的列表  # ...  # 将当前层的缓存 (如果存在) 和 cache_position 传递给 GPT2Block            # GPT2Block 内部会再将其传递给 GPT2Attention            layer_outputs = block(  hidden_states,  layer_past=None, # 旧参数,通常为None  attention_mask=extended_attention_mask,  head_mask=head_mask[i],  encoder_hidden_states=None,  encoder_attention_mask=None,  use_cache=use_cache,  output_attentions=output_attentions,  past_key_value=past_key_values, # 传递整个缓存对象  cache_position=cache_position,  )  hidden_states = layer_outputs[0] # 更新 hidden_states            # 如果 use_cache,block 会返回更新后的缓存,这里 past_key_values 已被更新  # (在 Cache 对象实现中,update 方法通常返回更新后的完整缓存状态,  #  或者直接在对象内部修改,取决于具体实现)  # ... (处理输出) ...  return BaseModelOutputWithPast(  last_hidden_state=hidden_states,  past_key_values=past_key_values if use_cache else None, # 返回更新后的缓存  hidden_states=all_hidden_states,  attentions=all_self_attentions,  )

5.5 因果掩码 (Causal Mask) 与 KV 缓存的配合

在自回归生成中,模型只能注意到当前 token 及其之前的所有 token,不能注意到未来的 token。这是通过因果掩码实现的。当使用 KV 缓存时,因果掩码的构建需要考虑到缓存中已有的 token 数量。

class GPT2Attention(_GPT2Attention):  def _update_causal_mask(  self,  attention_mask: torch.Tensor, # 原始的 attention_mask (可能包含 padding)        input_tensor: torch.Tensor,   # 当前输入的 hidden_states        cache_position: torch.Tensor,  past_key_values: Cache,       # 当前的缓存对象  output_attentions: bool,  ):  # 获取当前输入的序列长度 (通常为1,在自回归生成的每一步)  input_seq_length = input_tensor.shape[1]  # 获取缓存中已有的序列长度  past_seen_tokens = past_key_values.get_seq_length(self.layer_idx)  # 总的上下文长度 = 缓存长度 + 当前输入长度  total_context_length = past_seen_tokens + input_seq_length  # _prepare_4d_causal_attention_mask_with_cache_position 会生成一个正确的掩码  # 这个掩码会确保:  # 1. 查询 Q (来自当前输入) 只能注意到键 K (来自缓存+当前输入) 中对应位置及之前的部分。  # 2. 处理好 padding (如果 attention_mask 中有指示)。  # 形状通常是 (batch_size, 1, query_length, key_length)        # 其中 query_length 是当前输入的长度 (如1)  # key_length 是总的上下文长度 (past_seen_tokens + input_seq_length)        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(  attention_mask,  input_shape=(input_tensor.shape[0], input_seq_length), # 当前输入的形状  target_length=total_context_length, # K, V 的总长度  dtype=input_tensor.dtype,  cache_position=cache_position, # 关键!用于确定当前 Q 在 K,V 序列中的相对位置  )  return causal_mask

_prepare_4d_causal_attention_mask_with_cache_position 这个辅助函数会创建一个上三角矩阵(或类似结构),其中未来的位置会被掩盖掉(例如,设置为一个非常小的负数,以便 softmax 后变为0)。cache_position 在这里的作用是,确保即使当前查询 Q 的序列长度很短(例如为1),它在与历史的 K, V 进行比较时,依然能正确地只关注到历史和当前 K, V 中该 Q 之前的部分。

5.6 支持多种高效的注意力实现

Hugging Face Transformers 库允许 GPT-2(以及其他模型)利用更高效的注意力后端实现,例如:

  • eager: PyTorch 的标准、原生注意力实现。
  • sdpa (Scaled Dot Product Attention): PyTorch 2.0 引入的高度优化的注意力函数 torch.nn.functional.scaled_dot_product_attention。它通常比 eager模式更快,内存效率也更高,并且可以自动选择最优的底层实现(如 FlashAttention 或 memory-efficient attention)。
  • flash_attention_2: 直接集成 FlashAttention v2 库。这是一种专门为现代 GPU 设计的、IO 感知的精确注意力算法,速度非常快,内存占用小。

KV 缓存机制的设计需要与这些高效实现兼容。例如,torch.nn.functional.scaled_dot_product_attention 和 FlashAttention 都支持直接传入包含历史和当前信息的完整 K, V 张量。cache_position 在这里尤为重要,因为它可以帮助这些高效后端理解哪些部分是新的,哪些是旧的,以及如何正确应用因果掩码。

# 在 GPT2Attention 的 forward 方法中
self.config._attn_implementation 存储了选择的注意力实现方式 ("eager", "sdpa", "flash_attention_2")  ... (计算 query, key, value) ...  
... (更新 key, value 使用 past_key_value 和 cache_position) ...  
此时 key 和 value 是拼接/更新后的完整 K, V  if self.config._attn_implementation == "sdpa":  # 使用 PyTorch SDPA    # is_causal=True 会自动应用因果掩码  # attn_mask 可能需要根据 SDPA 的要求进行调整  attn_output = torch.nn.functional.scaled_dot_product_attention(  query, key, value, attn_mask=adjusted_attn_mask, dropout_p=self.attn_dropout.p, is_causal=True  )  
elif self.config._attn_implementation == "flash_attention_2":  # from flash_attn import flash_attn_func  # 可能需要对 query, key, value 的形状或数据类型进行调整以适应 flash_attn_func    # causal=True 会应用因果掩码  attn_output = flash_attn_func(  query.transpose(1, 2), # FlashAttention 可能期望 (batch, seq_len, num_heads, head_dim)        torch.stack((key.transpose(1,2), value.transpose(1,2)), dim=0), # K, V 打包  dropout_p=self.attn_dropout.p,  causal=True,  )  
else: # "eager"  # ... (标准的 PyTorch matmul 实现) ...

5.7 KV 缓存的完整工作流程 (自回归生成)

5.7.1 初始步骤 (t=0):
  • 用户提供初始的 input_ids (例如,一个 [BOS] token 或者一段提示文本)。
  • past_key_valuesNone
  • 模型 forward 方法被调用。
  • use_cache 通常为 True
  • 初始化一个空的 DynamicCache 对象作为 past_key_values
  • 计算 cache_position,此时它通常是从 0 开始的序列 (e.g., torch.arange(0, initial_input_len)).
  • 对于每一注意力层:
    • 计算当前 input_ids 对应的 Q, K, V。
    • 由于 past_key_values 刚被初始化(内部缓存为空),update 方法会将这些新计算的 K, V 存入缓存的第一批位置。
    • 使用这些 K, V (此时它们只包含当前输入的信息) 和 Q 进行注意力计算。
  • 模型输出 logits (用于预测下一个 token) 和更新后的 past_key_values (现在包含了第一个输入的 K,V)。
5.7.2 后续步骤 (t > 0):
  • 从上一步的 logits 中采样得到新的 input_ids (通常是一个新的 token)。
  • 将上一步返回的 past_key_values (包含了 t-1 步及之前所有 token 的 K,V) 作为输入传递给模型。
  • 模型 forward 方法再次被调用。
  • use_cacheTrue
  • 计算 cache_position。此时,past_key_values.get_seq_length() 会返回已缓存的 token 数量 (例如 t)。新的 cache_position 会是 torch.tensor([t]),表示这个新 token 是序列中的第 t+1 个元素 (如果从1开始计数的话,或者第 t 个位置如果从0开始计数)。
  • 对于每一注意力层:
    • 只对新输入的单个 token 计算其 Q, K, V (这些是"小"张量)。
    • 调用 past_key_values.update(new_key, new_value, layer_idx, cache_kwargs={"cache_position": cache_position})
      • update 方法会将这个新 token 的 K, V 追加到对应层缓存中已有的 K, V 之后,并返回完整的 K (包含所有 t+1 个 token) 和完整的 V。
    • 使用新 token 的 Q 和完整的 (历史+当前) K, V 计算注意力。因果掩码会确保 Q 只注意到 K,V 中它自己及之前的部分。
  • 模型输出 logits 和再次更新后的 past_key_values

这个过程一直重复,直到生成了 [EOS] token 或达到最大长度。

KV 缓存的显著优势

  1. 避免冗余计算: 这是最核心的优势。在生成第 t 个 token 时,前 t-1 个 token 的 K 和 V 向量已经计算并存储在缓存中,无需重新计算。注意力机制只需要为新的当前 token 计算 K 和 V,然后将它们与缓存中的历史 K,V 结合起来。
  2. 显著提高生成速度: 尤其对于长序列生成,每次迭代的计算量从 O(N²)(N为当前总长度)降低到接近 O(N)(主要是新 Q 与历史 K,V 的交互),因为主要计算瓶颈(K,V的生成)只针对新token进行。
  3. 支持高效的批处理生成: 虽然每个样本在批次中可能有不同的已生成长度(特别是在使用可变长度输入或某些采样策略时),通过 cache_position 和可能的填充/掩码机制,KV 缓存可以有效地处理这种情况。
  4. 与先进注意力实现的兼容性: 如前所述,KV 缓存的设计与 SDPA、FlashAttention 等高效后端良好集成,使得模型可以同时享受到算法优化和底层硬件加速的好处。
http://www.xdnf.cn/news/530803.html

相关文章:

  • RK3568解码1080P视频时遇到系统崩溃内核挂掉的解决方案
  • C语言:在操作系统中,链表有什么应用?
  • 安全强化的Linux
  • RLᵛ_ Better Test-Time Scaling by Unifying LLM Reasoners With Verifiers
  • 【TTS回顾】Bert-VITS2深度解析:融合BERT的多语言语音合成模型
  • 详细总结和讲解redis的基本命令
  • JavaScript 性能优化实战指南
  • Unity3D HUD UI性能优化方案
  • 卓力达手撕垫片:精密制造的创新解决方案与多领域应用
  • Unreal Engine: Windows 下打包 AirSim项目 为 Linux 平台项目
  • 【成品设计】STM32和UCOS-II的项目
  • 软考教材重点内容 信息安全工程师 25章 移动安全 26章 大数据安全
  • Flask 与 Django 服务器部署
  • 【成品设计】基于STM32的的宠物看护系统
  • 论文阅读--Logical quantum processor based on reconfigurable atom arrays
  • ModbusTCP转 Profinet网关:热收缩包装机智能化改造核心方案
  • 深入理解 Redisson 看门狗机制:保障分布式锁自动续期
  • chirpstack v4版本 全流程部署[ubuntu+docker]
  • Linux 移植 Docker 详解
  • LeetCode 925. 长按键入 java题解
  • MIME类型详解及应用案例
  • JVM频繁FullGC:面试通关“三部曲”心法
  • 力扣992做题笔记
  • P2P最佳网络类型
  • YOLO11解决方案之实例分割与跟踪探索
  • 2025.05.01【Barplot】柱状图的多样性绘制
  • 【图像大模型】FLUX.1-dev:深度解析与实战指南
  • 五分钟本地部署大模型
  • stata入门学习笔记——导入数据
  • 二元Logistic回归