[AI算法] 什么事RoPE scaling
文章目录
- RopeScaling 的作用
- 💡 RopeScaling 的核心思想:
- 常见的 RoPE Scaling 方法
- Dynamic NTK-Aware Scaling
- 核心原理
- 实现方式(伪代码示例)
- 优点与效果
- 应用场景
- 总结对比表
- YaRN技术
RopeScaling 的作用
- ✅ 场景背景:
- 模型在训练时使用的最大上下文长度是有限的(如 2048 或 4096 tokens)。
- 实际推理或部署时,可能需要处理更长的文本(如 8192 或 32768 tokens)。
💡 RopeScaling 的核心思想:
- 对原始 RoPE 的频率进行缩放(scaling),使得模型可以外推到更长的位置。
常见的 RoPE Scaling 方法
Dynamic NTK-Aware Scaling
- 是一种基于 Neural Tangent Kernel (NTK- 神经切核) 理论的位置编码缩放方法,可以根据当前输入的序列长度动态地调整 RoPE 的频率参数,从而实现对长序列的良好支持。
核心原理
-
NTK 理论基础:根据 NTK 理论,神经网络在无限宽极限下可以看作一个核函数(Kernel),其泛化能力与输入数据的分布和位置编码密切相关。
-
在 RoPE 中,位置相关的旋转角度由一组固定的频率(inv_freq)决定;Dynamic NTK 方法通过调整这些频率,使模型在不同长度下保持相似的注意力行为。
-
动态缩放公式
设:
L origin L_{\text{origin}} Lorigin:原始训练时的最大位置长度(如 2048)
L current L_{\text{current}} Lcurrent:当前输入的实际长度
α = ( L current L origin ) γ \alpha = \left( \frac{L_{\text{current}}}{L_{\text{origin}}} \right)^\gamma α=(LoriginLcurrent)γ:缩放因子(通常 γ = 0.3 \gamma=0.3 γ=0.3 或 0.5 0.5 0.5)
则新的频率为: new_inv_freq = original_inv_freq × α \text{new\_inv\_freq} = \text{original\_inv\_freq} \times \alpha new_inv_freq=original_inv_freq×α
其中 original_inv_freq 通常是按如下方式生成的(以 LLaMA 为例):
def precompute_freqs(dim: int, end: int, theta: float = 10000.0):freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))t = torch.arange(end, device=freqs.device)freqs = torch.outer(t, freqs).float()return freqs
在 Dynamic NTK 中,只需修改 end 或 freqs 即可实现动态调整。
实现方式(伪代码示例)
import math
import torchdef dynamic_ntk_scaling(max_seq_len, model_max_seq_len=2048, base_theta=10000.0, scaling_factor=1.0):if max_seq_len <= model_max_seq_len:return precompute_freqs(model_max_seq_len, base_theta)# 根据 NTK 论文建议,使用 log 缩放alpha = max(1, scaling_factor * (math.log(max_seq_len / model_max_seq_len) + 1))dim = ... # embedding dimensioninv_freq = 1.0 / (base_theta ** (torch.arange(0, dim, 2).float() / dim)) * alphat = torch.arange(max_seq_len, device=inv_freq.device)freqs = torch.outer(t, inv_freq).float()return freqs
优点与效果
应用场景
总结对比表
✅ 最佳实践建议
- 如果你正在使用 LLaMA 类模型并希望支持更长上下文(如 8k、32k),推荐使用 Dynamic NTK-Aware Scaling;
- 在加载模型权重时,替换原有的 inv_freq 为动态计算版本即可;
- 可结合 flash attention 和 paged attention 进一步优化显存效率;
- 如需极致长上下文支持,可考虑结合 YaRN 技术。