深入理解大模型分片优化:Late Chunking 技术解析
深入理解大模型分片优化:Late Chunking 技术解析
📌 背景:为什么需要文本分片(Chunking)
在使用大语言模型(如 Transformer)进行文本嵌入(Embedding)时,输入长度往往受限(例如 512 或 2048 token),而现实中的文本往往远超此限制。为解决这一问题,文本分片(Chunking) 技术被广泛使用。
🧠 什么是 Early Chunking vs Late Chunking
🧱 Early Chunking
在编码前,将长文本切分成多个小段(如按固定 token 数、换行符、标点等),每段单独送入模型,单独编码。
- 优点:实现简单
- 缺点:
- 每个 chunk 独立编码,上下文信息丢失
- 重复计算(重叠窗口)浪费资源
🧠 Late Chunking
先将完整文本一次性输入模型,获取全局 token 嵌入后,再根据预定义的 token 区间进行切片与聚合。
- 优点:
- 上下文信息保留(单次前向传播)
- 聚合逻辑灵活,切分方式可动态调整
- 减少模型调用次数,提高效率
🔍 使用场景
- 对大段文本做语义嵌入(如句子/段落级别表示)
- 信息抽取与摘要任务的输入预处理
- 文档级搜索与向量检索(RAG)中构建 chunk embedding
🧪 代码实践解析:chunk_by_sentences
以下是基于句号(.
)分句的分片函数,提取字符级到 token 级的 span:
from transformers import AutoModel, AutoTokenizertokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)
model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)def chunk_by_sentences(input_text: str, tokenizer: callable):inputs = tokenizer(input_text, return_tensors='pt', return_offsets_mapping=True)punctuation_mark_id = tokenizer.convert_tokens_to_ids('.')sep_id = tokenizer.convert_tokens_to_ids('[SEP]')token_offsets = inputs['offset_mapping'][0]token_ids = inputs['input_ids'][0]chunk_positions = [(i, int(start + 1))for i, (token_id, (start, end)) in enumerate(zip(token_ids, token_offsets))if token_id == punctuation_mark_idand (token_offsets[i + 1][0] - token_offsets[i][1] > 0or token_ids[i + 1] == sep_id)]chunks = [input_text[x[1] : y[1]]for x, y in zip([(1, 0)] + chunk_positions[:-1], chunk_positions)]span_annotations = [(x[0], y[0]) for (x, y) in zip([(1, 0)] + chunk_positions[:-1], chunk_positions)]return chunks, span_annotations
🔎 注意
[(1, 0)] + chunk_positions[:-1]
这一写法的意义是构造(start, end)
对,每一段的开始位置是前一个.
之后的 token。- 分片虽然以句号为界,但不会立即执行切分,而是记录每段的 token 索引区间。
🧠 为什么 chunks 能从字符 0 开始?
虽然 span_annotations 的起点是 token 编号(如 (1, 13)),但在构造 chunks
时,我们是基于 offset 映射(字符位置)来提取文本的:
chunks = [input_text[x[1]: y[1]]
]
x[1]
是字符起始位置,y[1]
是字符结束位置,因此文本从头截取是合理的,即使 token 从 1 开始。
🧪 late_chunking
函数详解
late_chunking
根据已有的 token span(如上面得到的 span_annotations
),在模型输出的 embedding 上做 pooling(通常是均值)。
def late_chunking(model_output: 'BatchEncoding', span_annotation: list, max_length=None):token_embeddings = model_output[0]outputs = []for embeddings, annotations in zip(token_embeddings, span_annotation):if max_length is not None:annotations = [(start, min(end, max_length - 1))for (start, end) in annotationsif start < (max_length - 1)]pooled_embeddings = [embeddings[start:end].sum(dim=0) / (end - start)for start, end in annotationsif (end - start) >= 1]pooled_embeddings = [embedding.detach().cpu().numpy() for embedding in pooled_embeddings]outputs.append(pooled_embeddings)return outputs
🧩 流程总结
- 对于每个样本中每段的
(start, end)
,提取对应 token 的 embedding; - 对该区间执行
mean pooling
; - 将结果转换为 numpy 向量,方便存储/检索。
🖼️ Early vs Late Chunking 对比图
左侧为 Early Chunking,右侧为 Late Chunking,展示了上下文保留与效率差异。
🧠 总结
特性 | Early Chunking | Late Chunking |
---|---|---|
模型调用次数 | 多次 | 一次 |
上下文信息 | 丢失 | 保留全局上下文 |
实现复杂度 | 简单 | 稍复杂(需记录 span) |
聚合策略灵活性 | 固定切片 | 高度灵活,支持任意切分逻辑 |
适合场景 | 简单 pipeline,token 限制较紧 | 向量检索、摘要、多模态、RAG 等复杂任务 |
📚 延伸阅读
- 《Scaling Transformer Embeddings via Late Interaction (Dense Retriever)》
- HuggingFace
offset_mapping
的使用指南 - 向量数据库(如 FAISS, Qdrant)中的 Chunking 策略
参考说明
作者:易迟
链接:https://zhuanlan.zhihu.com/p/885347223