Wan2.1 加速推理方法
Wan2.1 加速推理方法
flyfish
TeaCache 可在免训练的情况下将 Wan2.1 加速2倍,
在单张A800 GPU上,TeaCache对Wan2.1模型不同任务(文生视频t2v、图生视频i2v)和不同参数规模(1.3B/14B模型、480P/720P分辨率)的推理延迟优化效果
使用不同 teacache_thresh 值时,TeaCache-Wan2.1 生成的结果
1. TeaCache显著降低推理延迟,加速比可达2倍以上
- 文生视频(t2v):
- 1.3B模型(480P):原始延迟~175秒,使用TeaCache后最低降至88秒(阈值0.08),加速近2倍。
- 14B模型(720P):原始延迟~55分钟,阈值0.2时降至27分钟,加速超2倍。
- 图生视频(i2v):
- 480P:原始延迟~735秒,阈值0.26时降至300秒,加速2.4倍。
- 720P:原始延迟~29分钟,阈值0.3时降至12分钟,加速2.4倍。
规律:随着teacache_thresh
阈值提高(允许更多计算步骤被缓存复用),推理时间进一步缩短,加速效果更显著。
2. 不同模型/任务的加速效率差异
- 小模型(1.3B) vs 大模型(14B):
小模型本身推理快,TeaCache加速绝对值(秒级减少)不如大模型明显,但相对加速比相近(均接近2倍)。 - 低分辨率(480P) vs 高分辨率(720P):
高分辨率任务原始延迟更高,TeaCache优化后的时间节省更显著(如720P的i2v任务减少17分钟),说明其对计算密集型任务优化效果更突出。
3. teacache_thresh阈值的权衡
- 阈值作用:该参数控制缓存复用的激进程度,阈值越高,允许跳过的计算步骤越多,加速越快,但可能引入轻微质量损失(需结合具体场景调试)。
- 推荐实践:
- 若追求无损加速,选择较低阈值(如t2v 1.3B用0.05,对应1.5倍加速)。
- 若可接受轻微质量妥协,提高阈值以获得更高加速比(如t2v 14B用0.2,加速超2倍)。
4. 免训练特性的工程价值
TeaCache无需重新训练模型,只需调整参数即可实现加速,适用于生产环境快速部署,尤其适合对推理效率敏感的场景(如短视频生成、实时交互应用)。
TeaCache通过时间步输出波动分析与缓存复用,在不显著影响视觉质量的前提下,为Wan2.1提供了高效的推理加速方案,且对不同规模模型和任务均具有普适性。
技术原理
TeaCache的加速原理基于扩散模型时间步输出的波动分析与智能缓存复用,核心在于动态判断哪些计算步骤可跳过
一、核心洞察:时间步输出的“稳定-变化”规律
扩散模型(如图像/视频生成)通过逐步去噪生成内容:
- 早期时间步:快速生成轮廓和主体结构(如人脸、物体形状),输出波动小(变化集中在全局结构)。
- 后期时间步:细化局部细节(如纹理、光影),输出波动大(变化集中在像素级细节)。
TeaCache发现:早期步骤的输出稳定性高,相邻时间步的差异(如L1范数)常低于阈值,可直接复用缓存,避免重复计算。
二、技术核心:时间步嵌入感知的动态缓存
-
波动差异监测
- 通过计算相邻时间步的输出差异(如
rel_L1_thresh
相对阈值),判断是否复用缓存。若差异低于阈值(如0.4),直接使用前一步的缓存结果,跳过当前Transformer块的计算。 - 示例:生成视频时,若背景建筑在连续时间步中无显著变化,TeaCache会缓存其特征,仅计算人物动作等变化部分。
- 通过计算相邻时间步的输出差异(如
-
分层缓存策略
- 前层Transformer块优先缓存:扩散模型的前几层负责全局结构(如轮廓),后几层负责细节(如纹理)。TeaCache重点缓存前层输出,因为其稳定性高,复用价值大。
- 动态调整缓存粒度:根据输入内容的复杂度(如运动幅度、细节密度),自适应决定缓存范围。例如,静态图像可缓存更多步骤,动态视频则减少缓存以保留时序一致性。
-
免训练适配
- 无需修改模型参数或训练,直接在推理阶段分析输出波动。兼容LoRA、ControlNet等插件,无需重新适配。
三、阈值控制:速度与质量的权衡
teacache_thresh
参数:- 低阈值(如0.05-0.15):严格判断差异,仅缓存高度相似的步骤,实现无损加速(如Wan2.1视频加速1.6倍)。
- 高阈值(如0.2-0.3):允许更多缓存复用,加速显著(如Wan2.1图生视频加速2.4倍),但可能引入轻微背景模糊或细节丢失(如手部轮廓轻微变形)。
- 视觉影响规律:质量损失集中于非主体区域(如背景),主体(如人脸、物体)因后期步骤计算保留,受影响较小。
四、典型应用场景的加速机制
-
图像生成(如FLUX模型)
- 缓存早期生成的轮廓特征,跳过重复的全局结构计算。例如,30步生成中,通过阈值0.4可复用15步,加速近2倍,且主体清晰度保留。
-
视频生成(如Wan2.1、混元视频)
- 利用时序稳定性,缓存同一场景的背景、静态物体特征。例如,81帧视频中,每5帧复用一次背景缓存,减少约40%的计算量。
-
音频合成(如TangoFlux)
- 缓存稳定的基频和音色特征,跳过噪声相似的时间步,加速语音生成的连贯性。
五、与传统加速方法的区别
方法 | 原理 | 优缺点 |
---|---|---|
TeaCache | 动态缓存时间步输出 | 免训练、无损/轻微损,加速1.4-2.4倍 |
蒸馏 | 训练小模型替代原模型 | 需重新训练,质量损失显著 |
减少步数 | 直接截断时间步 | 简单粗暴,质量断崖式下降 |
原理示意图
时间步:t0 → t1 → t2 → t3 → ... → tN↓ ↓ ↓缓存 复用 缓存(当Δt<阈值)
TeaCache通过**“监测波动→智能缓存→动态复用”的闭环,在扩散模型的冗余计算中挖掘加速潜力**,实现了**“无需牺牲核心质量的高效推理”**,尤其适合对速度敏感的生成场景(如短视频量产、实时交互)。
符号说明:
- →:时间步推进方向
- ↓:当前时间步的计算操作
- 缓存(💾):当时间步输出差异 Δt > 阈值 时,计算并存储结果
- 复用(↩️):当时间步输出差异 Δt < 阈值 时,直接使用前序缓存结果,跳过计算
示例解析:
- t0:首个时间步,无缓存,强制计算并缓存结果(💾)。
- t1:计算前检测到与 t0 的差异 Δt < 阈值,复用 t0 缓存(↩️ t0),跳过当前计算。
- t2:与 t1 的差异 Δt > 阈值,重新计算并缓存结果(💾)。
- t3:检测到与 t2 的差异 Δt < 阈值,复用 t2 缓存(↩️ t2),跳过计算。
- 后续步骤:循环此逻辑,动态判断是否复用或缓存。
关键逻辑:
- 阈值控制:通过
rel_l1_thresh
参数定义差异容忍度,平衡加速与质量。 - 缓存节点:仅在差异较大的时间步(如 t0、t2)存储结果,减少内存占用。
- 计算跳过:差异小的时间步(如 t1、t3)直接复用缓存,节省算力。
TeaCache 在扩散模型中的核心实现
一、整体架构与输入输出
这是扩散模型的前向传播函数,支持两种模式:
- 文生视频(t2v):通过文本嵌入(
context
)生成视频 - 图生视频(i2v):通过 CLIP 图像特征(
clip_fea
)和条件视频(y
)生成视频
核心参数:
x
:输入视频张量列表t
:扩散时间步context
:文本嵌入clip_fea
/y
:图生视频模式的输入
输出:去噪后的视频张量列表
二、TeaCache 核心加速逻辑
TeaCache 的核心是通过时间步输出波动分析决定何时复用缓存:
if self.enable_teacache:# 1. 选择用于波动分析的输入(时间步嵌入)modulated_inp = e0 if self.use_ref_steps else e# 2. 奇偶时间步分别处理(条件/无条件分支)if self.cnt%2==0: # 偶时间步(条件分支)# 2.1 判断是否需要计算if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:should_calc_even = True # 起始/结束阶段强制计算else:# 2.2 计算当前与前一时间步的相对L1距离rel_l1 = ((modulated_inp-self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item()# 2.3 使用多项式函数调整阈值(非线性优化)self.accumulated_rel_l1_distance_even += rescale_func(rel_l1)# 2.4 与阈值比较,决定是否复用缓存if self.accumulated_rel_l1_distance_even < self.teacache_thresh:should_calc_even = Falseelse:should_calc_even = True# 2.5 保存当前时间步的嵌入用于下次比较self.previous_e0_even = modulated_inp.clone()else: # 奇时间步(无条件分支),逻辑同上...# 3. 根据判断结果执行计算或复用缓存if self.is_even:if not should_calc_even:x += self.previous_residual_even # 复用缓存的残差else:# 执行完整计算并保存残差ori_x = x.clone()for block in self.blocks:x = block(x, **kwargs)self.previous_residual_even = x - ori_xelse:...
三、关键技术细节
-
波动分析指标:
- 使用 相对L1距离(
rel_l1
)衡量相邻时间步的差异:
rel_l1 = |当前嵌入 - 前一嵌入| / |前一嵌入|
- 累积距离与
teacache_thresh
比较,决定是否复用
- 使用 相对L1距离(
-
非线性阈值调整:
rescale_func = np.poly1d(self.coefficients)
- 使用多项式函数(如二次函数)动态调整阈值敏感度
- 早期时间步容忍度高(允许更多缓存),后期更严格(保留细节)
-
缓存策略:
- 强制计算区间:
self.cnt < self.ret_steps
(前几步)和self.cnt >= self.cutoff_steps
(后几步)强制计算,确保关键信息不丢失 - 残差缓存:
不直接缓存输出,而是缓存残差(previous_residual_even
),复用时代价更低
- 强制计算区间:
-
双分支处理:
- 偶时间步(
self.cnt%2==0
)处理条件分支(文本/图像引导) - 奇时间步处理无条件分支(去噪过程)
- 分别维护独立的缓存和距离累积
- 偶时间步(
四、TeaCache 与原模型的对比
# 未启用TeaCache时的原始计算
else:for block in self.blocks:x = block(x, **kwargs) # 每步都执行完整计算
启用TeaCache后,通过智能判断跳过部分计算,加速比取决于:
teacache_thresh
:阈值越高,跳过的计算越多- 内容复杂度:静态场景加速比更高(波动小)
Wan2.1 模型集成 TeaCache 加速的核心实现
主要用于文生视频(t2v)和图生视频(i2v)任务
一、整体流程概述
-
分布式环境初始化
- 处理多 GPU 分布式训练/推理配置(如
world_size
、local_rank
),初始化进程组(dist.init_process_group
)。 - 若使用并行策略(如
ulysses_size
/ring_size
),需确保参数与 GPU 数量匹配。
- 处理多 GPU 分布式训练/推理配置(如
-
提示词扩展(可选)
- 通过本地 Qwen 模型扩展输入提示词,增强生成内容的细节(如从“猫”扩展为“戴墨镜的白色猫咪在冲浪板上”)。
-
模型管道创建
- 根据任务类型(t2v/i2v)创建对应的 Wan2.1 模型管道(
WanT2V
/WanI2V
)。 - 核心步骤:将 TeaCache 机制注入模型管道,修改模型的前向传播逻辑。
- 根据任务类型(t2v/i2v)创建对应的 Wan2.1 模型管道(
二、TeaCache 集成的关键代码段
1. 启用 TeaCache 并替换前向传播函数
# 文生视频场景
wan_t2v.model.__class__.enable_teacache = True # 启用 TeaCache 开关
wan_t2v.model.__class__.forward = teacache_forward # 替换为带缓存的前向传播函数
- 作用:将原始模型的
forward
方法替换为teacache_forward
(即之前分析的 TeaCache 核心逻辑函数),使模型在推理时执行缓存复用策略。
2. 初始化 TeaCache 参数
wan_t2v.model.__class__.cnt = 0 # 时间步计数器
wan_t2v.model.__class__.num_steps = args.sample_steps*2 # 总时间步数(扩散步数×2,因分奇偶步处理)
wan_t2v.model.__class__.teacache_thresh = args.teacache_thresh # 差异阈值(用户可调参数)# 累积距离和历史缓存初始化
wan_t2v.model.__class__.accumulated_rel_l1_distance_even = 0
wan_t2v.model.__class__.accumulated_rel_l1_distance_odd = 0
wan_t2v.model.__class__.previous_e0_even = None
wan_t2v.model.__class__.previous_e0_odd = None
wan_t2v.model.__class__.previous_residual_even = None
wan_t2v.model.__class__.previous_residual_odd = None
- 关键参数:
teacache_thresh
:控制缓存复用的阈值(如 0.15 表示允许相对 L1 距离累积不超过 0.15)。cnt
:记录当前处理的时间步,用于区分奇偶步(条件/无条件分支)。previous_residual
:缓存残差(当前步与前一步的输出差异),复用时直接叠加残差以减少计算量。
3. 多项式系数配置(自适应阈值调整)
if args.use_ret_steps: # 是否使用固定保留步数(起始/结束阶段强制计算)if '1.3B' in args.ckpt_dir:# 1.3B 模型的多项式系数(二次函数或更高次)wan_t2v.model.__class__.coefficients = [-5.21862437e+04, 9.23041404e+03, ...]elif '14B' in args.ckpt_dir:# 14B 模型的系数(更大模型需要更敏感的阈值调整)wan_t2v.model.__class__.coefficients = [-3.03318725e+05, 4.90537029e+04, ...]wan_t2v.model.__class__.ret_steps = 5*2 # 前 10 步强制计算(5 步×2 奇偶分支)wan_t2v.model.__class__.cutoff_steps = args.sample_steps*2 # 最后阶段强制计算
else:# 不使用保留步数时的默认系数(适用于轻量级场景)...
- 作用:
通过多项式函数(如y = ax² + bx + c
)动态调整阈值敏感度。- 早期时间步(如前 10 步)允许更大波动(快速生成轮廓),系数设置为高容忍度。
- 后期时间步(如最后 10 步)需保留细节,系数调整为低容忍度,减少缓存复用。
4. 模型类型与参数适配
# 图生视频(i2v)场景的系数配置(根据分辨率调整)
if '480P' in args.ckpt_dir:wan_i2v.model.__class__.coefficients = [2.57151496e+05, ...] # 低分辨率模型系数
elif '720P' in args.ckpt_dir:wan_i2v.model.__class__.coefficients = [8.10705460e+03, ...] # 高分辨率模型系数
- 逻辑:
高分辨率模型(如 720P)需要更精细的细节控制,因此系数设置更严格,减少后期步骤的缓存复用,避免模糊。
三、TeaCache 加速的核心逻辑串联
-
时间步奇偶分支处理
- 偶步(条件分支):处理文本/图像引导信号(如文本嵌入
context
、CLIP 特征clip_fea
),波动较小,适合缓存。 - 奇步(无条件分支):纯去噪过程,波动较大,缓存复用需更谨慎。
- 偶步(条件分支):处理文本/图像引导信号(如文本嵌入
-
强制计算区间
ret_steps
:前5*2=10
步强制计算,确保初始特征稳定(避免早期随机噪声导致的缓存错误)。cutoff_steps
:最后args.sample_steps*2
步强制计算,保留细节生成(如发丝、光影)。
-
残差缓存机制
- 不直接缓存完整输出,而是缓存残差(
x - ori_x
),复用时仅需叠加残差,而非重新计算整个 Transformer 块,进一步降低计算量。
- 不直接缓存完整输出,而是缓存残差(
内置属性__class__
在 Python 中,__class__
是一个内置属性,用于表示对象所属的类(即创建该对象的类)。它是面向对象编程中类与实例关系的核心概念之一,常用于类型检查、动态绑定方法或属性等场景。
__class__
是 Python 中连接对象与类的桥梁,核心作用包括:
- 获取对象所属的类。
- 在动态编程中修改类的行为(如替换方法、添加属性)。
- 在继承体系中追溯类的层级关系。
一、基础用法:获取对象的类
class MyClass:def __init__(self):self.value = 42obj = MyClass()
print(obj.__class__) # 输出:<class '__main__.MyClass'>
print(obj.__class__.__name__) # 输出:'MyClass'(类名)
- 作用:每个实例对象都有
__class__
属性,指向其对应的类对象。 - 等价写法:
type(obj)
与obj.__class__
通常返回相同结果(除非自定义了__class__
,但这种情况极少见)。
二、在继承中的表现
class Parent:passclass Child(Parent):passchild = Child()
print(child.__class__) # <class '__main__.Child'>
print(child.__class__.__base__) # <class '__main__.Parent'>(父类)
- 特性:子类实例的
__class__
指向子类本身,而非父类。 - 用途:可通过
__class__
追溯类的继承链(如obj.__class__.__mro__
查看方法解析顺序)。
三、动态绑定类属性或方法
在 Python 中,类和实例的属性可以动态修改。__class__
常用于在实例层面修改类的行为,例如:
class Dog:def bark(self):print("Woof!")# 创建实例
dog = Dog()
dog.bark() # 输出:Woof!# 通过 __class__ 动态修改类方法(所有实例生效)
def new_bark(self):print("Arf!")Dog.bark = new_bark # 直接修改类方法
dog.bark() # 输出:Arf!# 或通过实例的 __class__ 修改(仅影响该实例的类)
dog.__class__.bark = new_bark # 等价于上述操作
四、在 TeaCache 代码中的实际应用
在之前的代码中,TeaCache 通过 __class__
动态替换模型的前向传播方法:
wan_t2v.model.__class__.forward = teacache_forward
- 逻辑解析:
wan_t2v.model
是一个模型实例(如 Wan2.1 模型)。wan_t2v.model.__class__
获取该实例的类(如WanT2VModel
)。- 将类的
forward
方法替换为teacache_forward
函数,使所有该类的实例(包括后续创建的实例)都使用新的前向逻辑。
- 作用:实现了 免修改原模型代码 的加速逻辑注入,属于 运行时动态编程 技巧。
五、注意事项
-
只读性:
虽然__class__
属性通常可写,但不建议修改实例的__class__
(可能导致类型混乱)。例如:obj.__class__ = AnotherClass # 不推荐!可能引发不可预期的错误
-
与
type()
的区别:type(obj)
返回对象的类型,等价于obj.__class__
。type()
是函数,__class__
是属性,但最终效果一致:assert type(obj) is obj.__class__ # 通常为 True