当前位置: 首页 > news >正文

kvcache比赛记录

一些简单的记录


文章目录

  • test.py注释版
  • 文档处理
  • 预计算独立的KVCache
      • 目标:模拟“预埋”的KVCache
      • 例子 walkthrough
      • 为什么这个KVCache是“有瑕疵”的?
  • 开始推理
      • 核心原理:`llm.generate` 触发的“幕后”连锁反应
      • 例子 Walkthrough:深入模型内部
        • **进入第 0 层 (`layer_ind = 0`)**
        • **进入第 1 层 (`layer_ind = 1`)**
  • choose_recompute函数分析
      • 1\. 原理讲解:核心思想是什么?
      • 2\. 功能讲解:代码是如何工作的?
      • 3\. 举例说明:模拟执行过程

所有的精度指标(F1, Precision, Recall)都是通过将 Model Answer 与这个 Dataset Answer进行对比计算出来的

test.py注释版

# 导入所有必要的库
from vllm import LLM, SamplingParams  # vLLM的核心库,用于加载模型和进行推理
from utils import load_dataset, build_qa_prompt, scorer_all, extract_after_think  # 从utils.py导入辅助函数
from transformers import AutoTokenizer  # Hugging Face的库,用于加载分词器
import torch  # PyTorch库,用于张量操作
import os  # 用于处理文件路径
import sys  # 系统库
import numpy as np  # NumPy库,用于数值计算,如此处的平均值
import json  # 用于处理JSON数据格式# --- 1. 环境与路径配置 ---
# 原理讲解:
# 这部分代码设置了所有必要的路径和参数,是脚本运行的基础。
# 您需要根据自己的环境修改这些路径,特别是`base_dir`和`model_path`。base_dir = '/home/mzq/massive_storage'  # 设置一个基础目录,用于存放最终生成的JSON结果文件
num_runs = 1  # 设置要运行的数据集样本数量,这里设为1表示只跑数据集中的第一个问题
tp = 1  # Tensor Parallelism size,张量并行大小,设为1表示使用单GPU
max_tokens = 2048  # 设置模型生成的最大token数
dataset_name = 'just_for_test'  # 指定要使用的数据集文件名(不含.json后缀)
model = "DeepSeek-R1-Distill-Qwen-14B"  # 指定要加载的模型名称
model_path = f"/data/mzq/massive_storage/{model}"  # 拼接成完整的模型存放路径# --- 2. 数据与模型加载 ---
# 原理讲解:
# 这部分代码负责将模型、分词器和数据集加载到内存中。
# 同时,它还定义了用于构建完整请求的prompt模板。# 拼接出最终结果JSON文件的完整路径
json_path = os.path.join(base_dir, f'{model}-{dataset_name}-test.json')
# 加载评测数据集
eval_dataset = load_dataset(f"/home/mzq/massive_storage/vllm-ascend-dev/data/{dataset_name}.json")# 定义用于构建prompt的模板字符串
if dataset_name == 'just_for_test':# 这是系统提示(System Prompt),告诉模型它的角色和任务规则prefix_prompt = "你是一个有帮助且知识渊博的助手。你将得到一个问题和一组从知识库中检索到的文档。请仅使用提供的上下文信息来回答问题。如果上下文中没有足够的信息来回答问题,请如实地说明。需遵从下面的指令:1、你将得到一个用户问题和一组检索到的文档。2、仅使用提供的上下文来回答问题。3、如果问题无法在上下文中找到答案,请回答:“上下文没有提供足够的信息来回答这个问题。”4、简洁并事实性回答。\n文章:\n"# 这是查询提示,放在所有文档之后,引出用户的问题query_prompt = "请基于上述文章回答下面的问题。\n问题:"# 初始化vLLM引擎
llm = LLM(model=model_path,  # 模型路径max_model_len=20000,  # 模型能处理的最大序列长度tensor_parallel_size=tp,  # 张量并行大小enforce_eager=True,  # 强制使用Eager模式,便于调试和获取中间状态enable_chunked_prefill=False,  # 禁用分块预填充dtype='bfloat16')  # 使用bfloat16数据类型以节省显存# 加载模型对应的分词器
tokenizer = AutoTokenizer.from_pretrained(model_path)
# 加载模型的配置文件(config.json),以获取模型层数等信息
model_config = load_dataset(f'{model_path}/config.json')
num_layer = model_config["num_hidden_layers"]  # 获取模型的总层数# --- 3. 主循环与评测逻辑 ---
# 原理讲解:
# 这是脚本的核心。它遍历数据集中的每一个问题,并执行两个关键的推理步骤:
# 1. 预计算KVCache:模拟“预埋”操作,获取每个文档块独立的、无交叉注意力的KVCache。
# 2. 正式推理:使用选择性重算策略,对完整的长文本进行推理,并记录性能和精度。# 初始化用于存储所有问题结果的列表
ttft_blend = []  # 存储每个问题的TTFT
answers = []  # 存储标准答案
result_w_caches = []  # 存储模型答案
t_df1 = []  # 存储每个问题的F1分数
t_dpr = []  # 存储每个问题的Precision
t_drecall = []  # 存储每个问题的Recall# 打开结果文件,准备以追加模式('a')写入
with open(json_path, mode='a', newline='', encoding='utf-8') as file:file.write('[\n')  # 手动写入JSON数组的开头# 遍历数据集中的每一个样本(问题)for ii, ex in enumerate(eval_dataset):dict_obj = {}  # 用于存储当前问题结果的字典if ii == num_runs:  # 如果处理的问题数量达到了设定的num_runs,则跳出循环breakanswer = ex["answers"]  # 获取标准答案列表question = ex["question"]  # 获取问题字符串# 使用utils.py中的函数构建prompt列表doc_prompts, q_prompt = build_qa_prompt(ex, query_prompt)# 组合成最终的文档块列表,结构为:[头部prompt, 文档1, 文档2, ..., 尾部prompt]doc_list = [prefix_prompt] + doc_prompts + [q_prompt]# --- 3.1 预计算独立的KVCache ---# 原理讲解:# 这是为了模拟赛题中“预埋”的、缺少交叉注意力的KVCache。# 脚本对每个文档块独立进行一次推理,然后通过`hack_kv`这个“后门”变量,# “偷”出该次推理产生的KVCache。最后将所有块的KVCache拼接起来,# 形成一个完整的、但有瑕疵的`old_kvs`,作为后续“复用”的来源。sampling_params = SamplingParams(temperature=0, max_tokens=1)  # 设置采样参数,这里只需要生成1个token来触发prefill# 获取模型内部用于传递元数据的字典recompute_metadata = llm.llm_engine.model_executor.driver_worker.model_runner.model.model.recompute_metadata# 计算每个文档块的token长度,用于后续计算重算率doc_length = [len(tokenizer.encode(prefix_prompt))]for doc in doc_prompts:doc_length.append(len(tokenizer.encode(doc)) - 1)doc_length.append(len(tokenizer.encode(q_prompt)) - 1)recompute_metadata["doc_length"] = doc_lengthrecompute_metadata["kv_done"] = False  # 标记:此时“旧”的KVCache还没准备好chunk_past_key_values = []  # 用于存储拼接好的“旧”KVCache# 遍历每一个文档块,独立计算其KVCachefor i in range(len(doc_list)):prompts = doc_list[i]llm.generate(prompts, sampling_params)  # 对当前块进行推理# 直接访问模型底层,获取所有注意力层llm_layers = llm.llm_engine.model_executor.driver_worker.model_runner.model.model.layers# 遍历模型的每一层for j in range(num_layer):# 从我们修改的qwen2.py中,取出“偷”出来的K,V (`hack_kv`)past_key_values = llm_layers[j].self_attn.hack_kvif i == 0:  # 如果是第一个块(头部prompt)temp_k = past_key_values[0][:].clone()temp_v = past_key_values[1][:].clone()chunk_past_key_values.append([temp_k, temp_v])  # 初始化列表else:  # 如果是后续的块# 注意:[1:]是为了去掉每个块自带的BOS token,避免重复temp_k = past_key_values[0][1:].clone()temp_v = past_key_values[1][1:].clone()# 将当前块的K,V拼接到之前所有块的K,V之后chunk_past_key_values[j][0] = torch.cat((chunk_past_key_values[j][0], temp_k), dim=0)chunk_past_key_values[j][1] = torch.cat((chunk_past_key_values[j][1], temp_v), dim=0)# 将拼接好的、有瑕疵的KVCache存入模型,供后续使用llm.llm_engine.model_executor.driver_worker.model_runner.model.model.old_kvs = chunk_past_key_values# --- 3.2 执行正式的、带选择性重算的推理 ---# 原理讲解:# 现在,我们将所有文档拼接成一个完整的长文本进行推理。# 因为`kv_done`被设为True,我们修改过的`qwen2.py`会启动选择性重算逻辑,# 它会调用你在`functions.py`中写的算法来决定哪些token复用`old_kvs`,哪些重新计算。recompute_metadata["kv_done"] = True  # 标记:“旧”的KVCache已准备就绪,可以开始选择性重算了# 将所有文档块拼接成一个长字符串prompts = ''for doc in doc_list:prompts += doc# 设置正式推理的采样参数sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens, ignore_eos=False)# 执行生成任务,这将触发我们的选择性重算逻辑output = llm.generate(prompts, sampling_params)# --- 3.3 计算并验证重算率 ---# 原理讲解:# 推理结束后,从模型元数据中取出`valid_list`(由`qwen2.py`在每层记录的重算token数),# 计算出实际重算token总数,再除以理论最大重算token总数,得到重算率。valid_list = recompute_metadata["valid_list"]recompute_num = 0for u in valid_list:recompute_num += u  # 累加每层的重算数量# 计算理论最大重算数 = (文档总token数) * (模型层数)total_num = np.sum(doc_length[1:-1]) * num_layerrecompute_ratio = recompute_num / total_num  # 计算重算率if recompute_ratio >= 0.3:print(f"Error!recompute ratio {recompute_ratio} is too large!")else:print("这次回答重算率没超过0.3") # wgh 自己加的输出# --- 3.4 计算性能与精度指标 ---# 原理讲解:# 从vLLM的输出中提取TTFT。对于精度,使用模型的回答与数据集中每一个标准答案进行比较,# 取F1分数最高的那一次作为最终结果。res = output[0].outputs[0].text  # 获取模型最终的文本输出print(f"问题: {question}")print(f"模型答案: {res}")print(f"正确答案: {answer}")# 计算TTFTttft = output[0].metrics.first_token_time - output[0].metrics.first_scheduled_timeprint(f"TTFT with cache: {ttft}")ttft_blend.append(ttft)  # 存入列表# 初始化当前问题的最高精度分数temp_df1 = 0temp_dpr = 0temp_drecall = 0# 遍历所有可能的标准答案for j in range(len(answer)):# 调用评分函数,注意`extract_after_think`会去掉模型的思考过程df1, dpr, drecall = scorer_all('dureader_all', extract_after_think(res), str(answer[j]))# 保留分数最高的结果if df1 > temp_df1:temp_df1 = df1if dpr > temp_dpr:temp_dpr = dprif drecall > temp_drecall:temp_drecall = drecalldf1 = temp_df1dpr = temp_dprdrecall = temp_drecall# 将当前问题的最终分数存入列表t_df1.append(df1)t_dpr.append(dpr)t_drecall.append(drecall)# --- 3.5 保存结果到JSON文件 ---dict_obj["id"] = iidict_obj["Query"] = questiondict_obj["Model Answer"] = resdict_obj["Dateset Answer"] = answerdict_obj["F1 with Dataset"] = df1dict_obj["Precision with Dataset"] = dprdict_obj["Recall with Dataset"] = drecalldict_obj["TTFT"] = ttftjson.dump(dict_obj, file, indent=4, ensure_ascii=False)  # 将字典写入JSON文件file.write(',\n')  # 手动写入逗号和换行,为下一个对象做准备# --- 4. 计算并输出最终平均结果 ---# 循环结束后,计算所有问题指标的平均值res_obj = {}res_obj["avg_ttft"] = np.mean(ttft_blend)res_obj["avg_f1 with Dataset Answer"] = np.mean(t_df1)res_obj["avg_precision with Dataset Answer"] = np.mean(t_dpr)res_obj["avg_recall with Dataset Answer"] = np.mean(t_drecall)json.dump(res_obj, file, indent=4, ensure_ascii=False)  # 写入平均结果file.write(']\n')  # 手动写入JSON数组的结尾# 打印最终的平均结果到控制台
print(f"f1: {np.mean(t_df1)}, precision: {np.mean(t_dpr)}, recall: {np.mean(t_drecall)}, ttft: {np.mean(ttft_blend)}")

文档处理

我们先分析

        answer = ex["answers"]  # 获取标准答案列表question = ex["question"]  # 获取问题字符串# 使用utils.py中的函数构建prompt列表doc_prompts, q_prompt = build_qa_prompt(ex, query_prompt)# 组合成最终的文档块列表,结构为:[头部prompt, 文档1, 文档2, ..., 尾部prompt]doc_list = [prefix_prompt] + doc_prompts + [q_prompt]

我们这里先预处理一下文档

doc_list 大致结构如下

# 开头
你是一个有帮助且知识渊博的助手。你将得到一个问题和一组从知识库中检索到的文档。请仅使用提供的上下文信息来回答问题。如果上下文中没有足够的信息来回答问题,请如实地说明。需遵从下面的指令:1、你将得到一个用户问题和一组检索到的文档。2、仅使用提供的上下文来回答问题。3、如果问题无法在上下文中找到答案,请回答:“上下文没有提供足够的信息来回答这个问题。”4、简洁并事实性回答。
文章:# 中间的文档部分
<|User|>标题:为什么鼻子两侧总是红红的?_百度知道黑头的产生 黑头是硬化油脂阻塞物,通常出现在颜面的额头、鼻子等部位,当油脂腺受到过分刺激,毛孔充满多余的油脂而造成阻寒时,在鼻头及其周围部分,经常会有油腻的感觉。这些油脂最终会硬化,经氧化后成为黑色的小点,这些小点就是被称作黑头的油脂阻塞物。 错误去黑头方法:! 1、用手挤:很多人都会用手挤,但由于指甲易藏细菌,所以容易引致皮肤发炎,而且毛孔会越变越大。 2、用刷擦:这种方法只适用于去死皮,如去黑头,作用不大,若大力擦会擦损皮肤。 各路人马总结的有效的去黑头方: 一、盐加牛奶去黑头 1.最好用没有用过的食盐,可以在刚开封时用小瓶单独装起来; 2.每次用45滴牛奶兑盐,在盐半溶解状态下开始用来按摩; 3.由于此时的盐未完全溶解仍有颗粒,所以在按摩的时候必须非常非常小力; 4.半分钟后用清水洗去,不能太# 结尾部分请基于上述文章回答下面的问题。
问题:鼻子周围红红的
回答:<|Assistant|><think>

预计算独立的KVCache

通过一个巧妙的循环,用低成本的方式(独立计算+拼接)构建了一个不包含交叉注意力的“旧”KVCache (old_kvs)。

这个 old_kvs 就是您在 functions.py 中进行决策的基准。您的算法需要判断:对于拼接后的文本,哪些位置的token受交叉注意力影响巨大,以至于我们不能使用这个“有瑕疵”的旧KVCache,而必须重新计算它们,从而“修复”这个瑕疵

我们用一个简单的例子来把整个过程走一遍。


目标:模拟“预埋”的KVCache

想象一下,在真实世界里,我们可能会提前把很多文档(比如维基百科页面)单独处理好,存下它们的KVCache。当用户提问时,我们把相关的几个文档的KVCache拿出来,直接拼接在一起,希望能快速得到答案。

这部分代码就是在模拟这个“直接拼接”的过程

例子 walkthrough

假设我们的输入 doc_list 简化后是这样的(实际是很长的文档块):

doc_list = ["头部prompt",      # i = 0"文档A:猫喜欢鱼。",  # i = 1"文档B:狗喜欢骨头。",# i = 2"尾部prompt"       # i = 3
]

并且,为了简化,我们假设模型只有1层 (num_layer = 1)。

现在,我们来逐行过一遍代码:

1. 初始化

sampling_params = SamplingParams(temperature=0, max_tokens=1)
# ...
chunk_past_key_values = [] # 准备一个空列表,用来存放我们最终拼接好的KVCache
```chunk_past_key_values` 将会是一个列表,因为模型有 `num_layer` 层,它需要为每一层都存一份KVCache。因为我们假设只有1层,所以它最后只会包含一个元素 `[[K_tensor, V_tensor]]`。**2. `for i in range(len(doc_list))` 循环开始**这个循环会执行4次,每次处理 `doc_list` 中的一个块。---
**第一次循环 (i = 0, 处理 "头部prompt")**```python
prompts = doc_list[0]  # prompts = "头部prompt"
llm.generate(prompts, sampling_params) # 模型处理 "头部prompt"
# ...
past_key_values = llm_layers[0].self_attn.hack_kv # “偷”出KVCache
# 假设 KVCache_prompt = ([K_p], [V_p])if i == 0:temp_k = past_key_values[0][:].clone() # temp_k = [K_p]temp_v = past_key_values[1][:].clone() # temp_v = [V_p]chunk_past_key_values.append([temp_k, temp_v]) # 初始化列表
  • 发生了什么?:模型只看到了 “头部prompt”,并计算出了它的KVCache。因为是第一次循环 (i==0),我们直接把这份KVCache存入 chunk_past_key_values
  • 此时 chunk_past_key_values 的状态: [ [[K_p], [V_p]] ]

第二次循环 (i = 1, 处理 “文档A:猫喜欢鱼。”)

prompts = doc_list[1] # prompts = "文档A:猫喜欢鱼。"
llm.generate(prompts, sampling_params) # 模型处理 "文档A"
# ...
past_key_values = llm_layers[0].self_attn.hack_kv
# 假设 KVCache_A = ([K_A], [V_A])# 进入 else 分支
temp_k = past_key_values[0][1:].clone() # temp_k = [K_A] (去掉BOS token)
temp_v = past_key_values[1][1:].clone() # temp_v = [V_A] (去掉BOS token)# 这是最关键的一步!
chunk_past_key_values[0][0] = torch.cat((chunk_past_key_values[0][0], temp_k), dim=0)
chunk_past_key_values[0][1] = torch.cat((chunk_past_key_values[0][1], temp_v), dim=0)
  • 发生了什么?
    1. 模型只看到了 “文档A:猫喜欢鱼。”,它完全不知道前面还有个 “头部prompt”。因此,它计算出的 KVCache_A孤立的,不包含任何关于 “头部prompt” 的信息。
    2. torch.cat 函数的作用是拼接张量。代码把新算出来的 temp_k ([K_A]) 拼接到 chunk_past_key_values 中已经存在的K张量 ([K_p]) 的后面。temp_v 同理。
  • 此时 chunk_past_key_values 的状态: [ [[K_p, K_A], [V_p, V_A]] ]

第三次循环 (i = 2, 处理 “文档B:狗喜欢骨头。”)

这个过程和第二次完全一样。模型只看到了 “文档B”,独立地计算出 KVCache_B,然后代码把它拼接到 chunk_past_key_values 的末尾。

  • 此时 chunk_past_key_values 的状态: [ [[K_p, K_A, K_B], [V_p, V_A, V_B]] ]

循环结束

经过所有循环后,chunk_past_key_values 里存储的就是一个强行拼接起来的KVCache。

为什么这个KVCache是“有瑕疵”的?

现在对比一下我们得到的模拟KVCache理想KVCache的区别:

  • 我们得到的模拟KVCache

    • K_A 是在模型只看得到"文档A" 的情况下计算出来的。
    • K_B 是在模型只看得到"文档B" 的情况下计算出来的。
    • K_B 的计算完全没有考虑到 “文档A” 的存在
  • 理想的KVCache (完全重算)

    • 当模型一次性处理 “头部prompt 文档A 文档B” 时,在计算 K_B 的时候,模型会同时关注 “头部prompt” 和 “文档A” 的内容(这就是交叉注意力 cross-attention)。
    • 因此,理想情况下的 K_B 和我们模拟出来的 K_B完全不同的。

开始推理

这个部分会调用我们的choose_recompute 函数

        ###标记缓存准备就绪recompute_metadata = llm.llm_engine.model_executor.driver_worker.model_runner.model.model.recompute_metadatarecompute_metadata["kv_done"] = True###开始推理prompts = ''for doc in doc_list:prompts += docsampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens, ignore_eos = False)output = llm.generate(prompts, sampling_params)recompute_metadata = llm.llm_engine.model_executor.driver_worker.model_runner.model.model.recompute_metadata

我将结合一个具体的例子,为您详细讲解当 llm.generate 被调用时,底层发生了什么,以及您的 choose_recompute 函数是如何被一步步使用的。


核心原理:llm.generate 触发的“幕后”连锁反应

test.py 调用 output = llm.generate(prompts, sampling_params) 时,vLLM 框架会启动一个完整的、从头到尾的推理流程。因为我们修改了 qwen2.py,这个流程就变得很特别:

test.py (指挥官) -> vLLM框架 -> qwen2.py (调度官) -> functions.py (您,决策者)

这个连锁反应会在模型的每一层都发生一次。

例子 Walkthrough:深入模型内部

假设我们的模型简化为只有 2层,输入的完整 prompts 在分词后,文档部分有 10个 token

准备工作

  • recompute_metadata["kv_done"] 已经被设为 True
  • llm.llm_engine...old_kvs 里面已经存好了我们之前模拟的、有瑕疵的10个token的KVCache。

现在,llm.generate 开始执行…


进入第 0 层 (layer_ind = 0)
  1. vLLM的常规操作:模型接收完整的10个token的输入,正常计算出这一层全新的、包含了完整交叉注意力的 q_new, k_new, v_new

  2. qwen2.py 的调度:在 Qwen2Attention.forward 函数中,代码检测到 kv_doneTrue,于是它暂停了常规流程,开始执行我们的特殊逻辑。

  3. 调用您的 choose_recompute 函数

    • qwen2.py 准备好所有“情报”,调用您的函数。此时传递给您函数的参数是:
      • hidden_states: 完整的10个token的输入隐状态。
      • old_k, old_v: 我们之前准备好的、有瑕疵的KVCache。
      • layer_ind: 0
      • valid_ind: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] (在第0层,默认所有token都是候选重算对象)。
      • q, k, v: 就是第1步里算出来的 q_new, k_new, v_new
  4. 您的算法进行决策

    • 您的 choose_recompute 函数开始执行。假设您的算法(比如我之前给的“渐进式注意力重构”算法)在第0层的策略是:根据 v_newold_v 的差异,保留差异最大的50%的token进行重算。
    • 您的函数经过计算,返回了一个新的 valid_ind
      [1, 0, 1, 0, 1, 1, 0, 0, 1, 0] (5个1,5个0)
  5. qwen2.py 执行您的决策

    • qwen2.py 拿到了您返回的 [1, 0, 1, ...]
    • 它找到所有标记为 0 的位置(第1, 3, 6, 7, 9个token)。
    • 在这些位置上,它执行替换操作:用 old_kold_v 中对应位置的向量,覆盖掉 k_newv_new 中对应位置的向量。
    • 现在,这一层的KVCache变成了一个“混合体”:5个token用的是新鲜出炉、包含全局信息的K/V,另外5个token用的则是旧的、有瑕疵的K/V。
  6. 进入下一层:模型使用这个“混合KVCache”完成第0层的注意力计算,然后生成进入第1层的 hidden_states一个至关重要的细节:此时,只有那5个被标记为1的token的hidden_states才会被计算并传递到下一层。


进入第 1 层 (layer_ind = 1)
  1. vLLM的常规操作:模型接收来自上一层的、只有5个tokenhidden_states,并为它们计算出全新的 q_new, k_new, v_new

  2. qwen2.py 的调度:同样,检测到 kv_doneTrue,暂停流程。

  3. 再次调用您的 choose_recompute 函数

    • qwen2.py 再次准备“情报”并调用您。这次的参数是:
      • hidden_states: 只有5个token的隐状态。
      • old_k, old_v: 还是那份完整的、有瑕疵的KVCache。
      • layer_ind: 1
      • valid_ind: [1, 0, 1, 0, 1, 1, 0, 0, 1, 0] (这是上一层决策的结果)。
      • q, k, v: 只有5个token的 q_new, k_new, v_new
  4. 您的算法再次决策

    • 您的 choose_recompute 函数再次执行。假设您在第1层的策略是:在上一层保留的5个token中,根据注意力熵,再淘汰掉40%,只保留最重要的3个。
    • 您的函数返回了最终的 valid_ind
      [1, 0, 0, 0, 1, 1, 0, 0, 0, 0] (3个1,7个0)。注意,这个结果必须是上一步输入的子集。
  5. qwen2.py 再次执行决策

    • qwen2.py 拿到这个最终版的 valid_ind,再次进行K/V替换操作。

这个过程会贯穿模型的所有层,每一层都会调用一次您的 choose_recompute 函数,让您有机会根据当前层的信息,逐步“精炼”出那些真正需要重算的核心token。

当所有层都处理完毕后,llm.generate 才算完成了 prefill 阶段,并返回最终的 output

choose_recompute函数分析

这个函数是赛题官方提供的一个基础示例,它实现了 CacheBlend 论文中的一种简化思想。理解它的工作原理是您进行优化的基础。

下面我将为您详细讲解这个函数的原理和功能,并用一个具体的例子来模拟它的执行过程。


1. 原理讲解:核心思想是什么?

这个算法的核心思想非常直接,可以概括为以下三点:

  1. 一次性决策:它并不在模型的每一层都做决策,而是选择在一个固定的、靠前的层(这里是第二层,layer_ind == 1)做一次“一锤子买卖”的决策。
  2. 后续层沿用:一旦在第二层决定了哪些token需要重算,哪些可以复用,这个决定就会被固定下来,在后续所有更深的层(layer_ind > 1)中都沿用这个决定,不再改变。
  3. 决策标准:信息变化量:它判断一个token是否需要重算的唯一标准是:当把所有文档拼接起来后,这个token的信息表示 (V向量) 发生了多大的变化。如果一个token的 V 向量相比于它在单个文档中时的 V 向量变化巨大,就说明它受到了其他文档的强烈影响,因此必须重算才能获得正确的上下文信息。

2. 功能讲解:代码是如何工作的?

我们来逐行分析这段代码的功能。

# 设置一个固定的重算比例,这里是25%
recompute_ratio = 0.25# 获取头部和尾部prompt的长度,用于后续精确地切片出文档部分的向量
begin = doc_length[0]
end = doc_length[-1]# 获取文档部分的总token数
num_tokens = len(valid_ind)# 根据比例计算出具体要重算多少个token
topk_num = int(num_tokens * recompute_ratio)# 关键判断:只在第二层 (layer_ind == 1) 执行决策逻辑
if layer_ind == 1:if topk_num != 0:# --- 这是算法的核心计算步骤 ---# 1. v[begin:-end] 是当前层根据完整上下文算出的、全新的V向量(只取文档部分)。# 2. old_v 是我们之前预埋的、缺少交叉注意力的、有瑕疵的旧V向量。# 3. (v[...] - old_v)**2 计算两者之差的平方,得到每个元素的差异值。# 4. torch.sum(..., dim=1) 将每个token向量的所有元素差异值相加,#    得到一个代表该token总信息变化量的分数。temp_diff = torch.sum((v[begin:-end] - old_v)**2, dim=1)# 从所有token的变化量分数中,找出分数最高的 topk_num 个token的索引top_indices = torch.topk(temp_diff, k=topk_num).indices# 遍历所有文档tokenfor i in range(num_tokens):# 如果一个token本来是候选(valid_ind[i] == 1),# 但它不在我们刚刚选出的“变化最大”的top_indices列表里,# 那么就把它从重算名单中剔除(设置为0)。if valid_ind[i] == 1 and i not in top_indices:valid_ind[i] = 0else:# 如果计算出的重算数量为0,则所有token都不重算valid_ind = [0] * num_tokens# 返回最终的决策列表。
# 注意:对于所有其他层 (layer_ind != 1),这个if块不会执行,
# 函数会直接返回从上一层接收到的valid_ind,不做任何修改。
return valid_ind

3. 举例说明:模拟执行过程

假设我们的文档部分共有10个token,并且现在正好进入了第二层 (layer_ind = 1)

  1. 参数初始化

    • num_tokens = 10
    • recompute_ratio = 0.25
    • topk_num = int(10 * 0.25) = 2。算法的目标是选出最重要的2个token进行重算。
    • 此时从第0层传来的 valid_ind[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],表示所有token都是候选。
  2. 计算信息变化量

    • 代码执行 temp_diff = torch.sum(...)
    • 假设计算出的10个token的信息变化量分数是:[0.2, 8.1, 1.5, 0.9, 9.5, 3.2, 0.1, 7.4, 2.8, 1.1]
  3. 找出最重要的Token

    • 代码执行 torch.topk(temp_diff, k=2)
    • 它会找到分数最高的两个值:9.5 (在索引4的位置) 和 8.1 (在索引1的位置)。
    • 所以,top_indices 列表就是 [4, 1]
  4. 更新决策列表

    • 代码开始 for 循环,遍历索引 09
    • 索引0: 0 not in [4, 1] -> valid_ind[0] 变成 0
    • 索引1: 1 in [4, 1] -> valid_ind[1] 保持 1
    • 索引2: 2 not in [4, 1] -> valid_ind[2] 变成 0
    • 索引3: 3 not in [4, 1] -> valid_ind[3] 变成 0
    • 索引4: 4 in [4, 1] -> valid_ind[4] 保持 1
    • …以此类推…
  5. 返回决策结果

    • 循环结束后,valid_ind 变成了 [0, 1, 0, 0, 1, 0, 0, 0, 0, 0]
    • 函数将这个列表返回给 qwen2.py
  6. 后续层的影响

    • 当模型进入第三层 (layer_ind = 2) 时,choose_recompute 函数会再次被调用。
    • 但这一次,if layer_ind == 1: 的条件不满足,所以函数会直接 return valid_ind,也就是把 [0, 1, 0, 0, 1, 0, 0, 0, 0, 0] 这个列表原封不动地返回。
    • 在所有更深的层(3, 4, 5…)都是如此。

通过这个例子,您可以看到,这个示例算法实现了一个非常简单但有效的策略:在早期(第二层)识别出受上下文影响最大的25%的token,然后就锁定这些token,在后续所有层都只为它们进行重算

http://www.xdnf.cn/news/1347733.html

相关文章:

  • 集群与负载均衡:HAProxy 与 Nginx 实践
  • 融云Im单独一个拍照或者拍摄插件Plugin
  • 自学嵌入式第二十五天:数据结构-队列、树
  • 配电网重构优化:以减小网损为目标的智能算法实现
  • 20250822给荣品RD-RK3588开发板刷Rockchip原厂的Android14时点亮荣品的8寸屏
  • SN编码升级:从“制造标记”到“数字孪生身份证”
  • There are test failures. clean deploy 异常
  • [RestGPT] RestGPT智能体
  • Bluedroid vs NimBLE
  • 20.9 QLoRA微调实战:1.5B参数Whisper-large-v2在24GB显存实现中文语音识别,CER骤降50%!
  • 使用tauri打包cocos小游戏,并在抖音小玩法中启动,拿到启动参数token
  • ​Kubernetes 详解:云原生时代的容器编排与管理
  • python selenium+pytest webUI自动化基础框架
  • Java 18 新特性及具体应用
  • linux----进度条实现和gcc编译
  • 基于海光DCU平台的cube-studio软件适配
  • BurpSuite 1.4.07.jar 怎么使用?详细安装和抓包教程(附安装包下载)
  • 前端查漏补缺
  • DAY01:【DL 第一弹】深度学习的概述
  • 机器学习在量化中的应用
  • 【计算机网络】 IPV4和IPV6区别
  • 【虚拟化】磁盘置备方式的性能损耗对比
  • MPLS原理
  • 基于SamGeo模型和地图客户端的实时图形边界提取
  • Rust Web开发指南 第一章
  • 计算机网络:TCP、UDP
  • 【Dubbo】高性能的 RPC
  • RK3506 开发板:重塑嵌入式系统领域的新标杆
  • 整数规划学习总结
  • 靶机 - SAR