[e3nn] 模型部署 | TorchScript JIT | `@compile_mode`装饰器 | Cython
第6章:TorchScript JIT支持
欢迎来到e3nn
的最终章~
在第5章:归一化中,我们学习了如何通过特征缩放使e3nn
模型在训练过程中保持稳定。
现在,让我们讨论如何让训练好的模型变得快速并准备好部署。
假设您已经使用e3nn
构建了一个出色的等变神经网络,可以高精度预测分子性质。
您希望在生产环境中使用这个模型,可能是在移动设备
、云服务器
上,或作为大型高性能模拟的一部分。在这些场景中,模型的速度和效率
至关重要。
这就是TorchScript JIT支持的用武之地。PyTorch提供了一个名为TorchScript和即时(JIT)编译的功能,允许将Python PyTorch模型转换为优化的序列化表示
。
这个编译版本运行更快,并且可以在不需要完整Python环境的情况下部署。
TorchScript为e3nn
解决什么问题?
TorchScript的目标是优化和序列化PyTorch模型。对于大多数标准PyTorch模型来说,这很简单。
然而,e3nn
模型提出了独特的挑战:
-
动态与静态操作:
e3nn
模块通常混合两种操作:- 静态图操作:许多核心
e3nn
操作,如张量积中的操作,是计算密集的张量操作(如torch.einsum
)。一旦Irreps
已知,这些操作的确切序列就固定了。 - 数据相关控制流:
e3nn
模型的其他部分可能包含Pythonif
语句或for
循环,其执行路径取决于运行时张量的实际值(例如,“如果这个范数超过某个阈值,执行X,否则执行Y”)。
- 静态图操作:许多核心
-
TorchScript的两种模式:PyTorch的JIT编译有两种主要模式,每种适用于不同场景:
- 追踪(
torch.jit.trace
):记录模型在示例输入上运行时的操作。它创建一个静态计算图
。这对性能很好,但不能处理数据相关控制流(例如,条件随输入值变化的if
语句)。 - 脚本化(
torch.jit.script
):直接分析Python源代码并将其编译为TorchScript。这可以处理数据相关控制流,但仅支持Python特性的子集(例如,它不总是能很好地处理复杂继承或动态类创建)。
- 追踪(
e3nn
的挑战在于其模块通常需要两者。例如,像e3nn.nn.Gate
这样的模块可能包含数据相关的if
语句(需要脚本化),但其内部可能使用e3nn.o3.TensorProduct
(它依赖于基于Irreps
的e3nn
内部代码生成,更适合追踪或专门处理)。直接对e3nn
模型应用标准torch.jit.script
或torch.jit.trace
通常会导致错误。
为了解决这个问题,e3nn
提供了**e3nn.util.jit
**,这是一组专门处理编译过程的工具。
它自动找出编译e3nn
模型不同部分的最佳方式,根据需要混合脚本化和追踪
,并向用户隐藏这种复杂性。
🎢@compile_mode
装饰器
@compile_mode
是Python中一个用于控制代码编译行为的装饰器,通常与元编程或代码生成技术结合使用。它能够修改函数或类的编译方式,例如改变字节码生成规则或启用特定优化。
常见场景:
- 代码优化:通过装饰器标记需要特殊编译优化的函数,例如
禁用断言
或内联
特定操作。 - 语法扩展:配合
自定义解析器
实现非标准语法(如领域特定语言)。 - 调试工具:在开发阶段
动态注入调试代码
或性能分析逻辑。
code:
def compile_mode(optimize=False):def decorator(func):func.__compile_mode__ = {'optimize': optimize}return funcreturn decorator@compile_mode(optimize=True)
def heavy_calculation():# 该函数会被标记为需要优化编译return sum(x*x for x in range(10**6))
注意事项:
- 该装饰器并非Python标准库内置功能,通常由第三方库(如Numba或PyPy)提供实现
- 具体效果取决于底层编译器或解释器的支持程度
- 过度使用可能导致代码可读性下降和调试困难
典型实现库:
- Numba:通过
@jit
装饰器实现即时编译优化 - Cython:使用
@cython.compile
执行C语言转换 - PyPy:利用RLPython实现的特殊编译模式
🎢Cython
Cython 是一种编程语言,通过将 Python 代码编译为 C 语言来提升性能。核心功能包括类型声明和直接调用 C 库
。
@cython.compile
是实验性装饰器,可自动编译函数为 C 扩展模块。
以下代码展示如何用 @cython.compile
加速斐波那契数列计算:
import cython@cython.compile
def fib(n: cython.int) -> cython.int:a: cython.int = 0b: cython.int = 1for _ in range(n):a, b = b, a + breturn aprint(fib(10)) # 输出55
解析 :
- 类型声明:使用
cython.int
明确变量类型,避免 Python 动态类型开销。 - 循环优化:
for
循环会被编译为 C 级高效实现。 - 装饰器作用:
@cython.compile
自动生成 C 代码并编译为.so
或.pyd
文件。
性能对比 :
- 未优化 Python 实现可能慢 10-100 倍。Cython 通过静态类型和直接编译到 C 显著减少函数调用开销。
注意事项:
- 需安装 Cython 包:
pip install cython
- 实验性功能可能不稳定,生产环境建议用
.pyx
文件结合setup.py
编译。
e3nn.util.jit
的关键概念
让我们看看使e3nn
支持JIT编译的核心思想:
-
@compile_mode
装饰器:- 这是
e3nn
模块(如e3nn.o3.Linear
或e3nn.nn.Gate
)告诉e3nn.util.jit
它们希望如何被编译的方式。 - 这是一个Python装饰器,您可以在类定义上方添加:
@compile_mode('script')
:表示此模块希望使用脚本化编译(适用于数据相关控制流)。@compile_mode('trace')
:表示此模块希望使用追踪编译(适用于静态操作)。@compile_mode('unsupported')
:表示此模块无法编译为TorchScript。
- 这是
-
递归编译:
e3nn.util.jit
函数(script
、trace
、compile
)递归工作。当您请求编译主模块时,e3nn.util.jit
首先查看其子模块。- 它根据每个子模块的
@compile_mode
装饰器编译每个子模块。一旦子模块被编译(无论是脚本化还是追踪),它就被替换为其TorchScript版本。 - 处理完所有子模块后,主模块本身被编译,平滑集成已编译的部分。这对于处理复杂模块层次结构和混合编译模式至关重要。
如何使用e3nn.util.jit
我们主要通过e3nn.util.jit
的script
、trace
和compile
函数与其交互,这些函数作为PyTorch原生JIT编译函数的智能包装器。
示例1:使用e3nn.util.jit.script
脚本化模块
让我们定义一个简单模块,在其forward
方法中包含数据相关控制流,并在内部使用e3nn.o3.Norm
。e3nn.o3.Norm
用于计算特征向量的范数,这是e3nn
模型中的常见操作。
import torch
from e3nn.o3 import Norm, Irreps
from e3nn.util.jit import script, trace, compile_modeclass MyModule(torch.nn.Module):def __init__(self, irreps_in) -> None:super().__init__()# Norm是一个专门的TensorProduct,直接脚本化很复杂self.norm = Norm(irreps_in)def forward(self, x):norm = self.norm(x)# 这是数据相关控制流:它取决于'norm'的值if torch.any(norm > 7.):return normelse:return norm * 0.5# 定义输入irreps(例如,两个标量和一个向量)
irreps = Irreps("2x0e + 1x1o")
mod = MyModule(irreps)print(f"原始模块:{mod}")
现在,尝试使用标准torch.jit.script
编译:
try:# 这可能会失败!mod_script_fail = torch.jit.script(mod)
except Exception as e:print(f"\n标准torch.jit.script失败:{e.__class__.__name__}: {e}")
输出(会变化但通常指示TorchScript限制):
标准torch.jit.script失败:TorchScriptError: 无法脚本化一个对象,该对象是定义了`__getstate__`或`__setstate__`的类型的子类,除非该对象是ScriptModule。
错误消息指向Norm
(它内部继承自TensorProduct
和CodeGenMixin
),特别是直接脚本化时的__getstate__
和__setstate__
或其他不支持的Python特性问题。
这就是e3nn.util.jit.script
来救援的时候
# 使用e3nn的智能脚本化函数
mod_script = script(mod)
print(f"\n使用e3nn.util.jit.script成功编译:{type(mod_script)}")# 验证输出相同
x = irreps.randn(2, -1) # 创建一些随机输入数据
assert torch.allclose(mod(x), mod_script(x))
print("原始模块和编译模块产生相同结果。")
输出:
使用e3nn.util.jit.script成功编译:<class 'torch.jit.ScriptModule'>
原始模块和编译模块产生相同结果。
解释:e3nn.util.jit.script
首先识别MyModule
的self.norm
是一个e3nn.o3.Norm
模块。
Norm
(如e3nn.o3.TensorProduct
)内部设计为由追踪处理,因为其内部结构在基于Irreps
初始化后是静态的。
-
因此,
e3nn.util.jit.script
首先追踪self.norm
,有效地将其转换为TorchScript图。 -
然后,它继续脚本化
MyModule
本身,此时self.norm
已经是torch.jit.script
可以轻松集成的有效TorchScript子模块。
这种递归、混合模式的编译是e3nn.util.jit
如此强大的原因。
示例2:显式混合追踪和脚本化
也可以使用@compile_mode
装饰器显式告诉e3nn
如何编译自定义模块。
让我们创建一个显式标记为script
模式的MyModule
,并将其嵌入到我们将要trace
的AnotherModule
中。
# 显式标记MyModule为脚本化
@compile_mode('script')
class MyModuleWithMode(torch.nn.Module):def __init__(self, irreps_in) -> None:super().__init__()self.norm = Norm(irreps_in)def forward(self, x):norm = self.norm(x)# 数据相关控制流for row in norm:if torch.any(row > 0.1):return rowreturn normclass AnotherModule(torch.nn.Module):def __init__(self, irreps_in) -> None:super().__init__()self.mymod = MyModuleWithMode(irreps_in) # 包含脚本化模块def forward(self, x):return self.mymod(x) + 3.irreps = Irreps("2x0e + 1x1o")
mod2 = AnotherModule(irreps)
print(f"原始AnotherModule:{mod2}")# 我们追踪AnotherModule,但其子模块MyModuleWithMode将被脚本化
example_inputs = (irreps.randn(3, -1),) # 追踪的示例输入
mod2_traced = trace(mod2, example_inputs)
print(f"\n追踪的AnotherModule:{type(mod2_traced)}")# 用不同输入测试追踪模块
print("\n全零输出:")
print(mod2_traced(torch.zeros(2, irreps.dim)))
print("\n随机输入输出:")
print(mod2_traced(irreps.randn(3, -1)))
输出(值会因随机输入而变化):
追踪的AnotherModule:<class 'torch.jit.ScriptModule'>全零输出:
tensor([3., 3., 3., 3., 3.])随机输入输出:
tensor([3.9114, 3.9114, 3.9114, 3.9114, 3.9114])
注意,即使AnotherModule
本身被追踪(trace
返回一个嵌入图的ScriptModule
),MyModuleWithMode
中的数据相关控制流(其for
循环和if
语句)被保留。通过检查模块类型可以确认这一点:
print(f"mod2_traced类型:{type(mod2_traced)}")
print(f"mod2_traced.mymod类型:{type(mod2_traced.mymod)}")
输出:
mod2_traced类型:<class 'torch.jit.ScriptModule'>
mod2_traced.mymod类型:<class 'torch.jit.ScriptModule'>
在这种情况下,两者都变成了ScriptModule
。
关键是e3nn.util.jit.trace
首先递归使用script
编译MyModuleWithMode
(因为其装饰器),然后编译AnotherModule
。
如果完全捕获包括子模块在内的所有图操作,torch.jit.trace_module
可以返回ScriptModule
。
示例3:使用_make_tracing_inputs
自定义追踪输入
有时,e3nn
无法自动猜测如何生成追踪的示例输入,特别是如果您的模块接受标准特征张量之外的复杂参数(例如,整数索引、张量列表)。
在这种情况下,您可以在模块中定义_make_tracing_inputs
方法。
@compile_mode('trace') # 我们希望追踪此模块
class TracingModule(torch.nn.Module):def forward(self, x: torch.Tensor, indexes: torch.LongTensor):return x[indexes].sum()# 此方法告诉e3nn.util.jit如何生成追踪的示例输入def _make_tracing_inputs(self, n: int):import random# 'n'是生成多少示例输入的建议# 我们需要返回字典列表,每个字典将方法名(如'forward')映射到示例参数元组return [{'forward': (torch.randn(5, random.randint(1, 3)), # 随机张量'x'torch.arange(random.randint(1, 3)) # 随机LongTensor'indexes')}for _ in range(n)]from e3nn.util.jit import compile # 直接使用`compile`以尊重装饰器mod3 = TracingModule()
mod3_traced = compile(mod3) # 这里不需要显式输入,_make_tracing_inputs提供它们
print(f"编译的TracingModule类型:{type(mod3_traced)}")# 测试追踪模块
test_x = torch.randn(5, 2)
test_indexes = torch.tensor([0, 1, 3]) # 注意:索引必须对test_x有效
print(f"输入x形状:{test_x.shape}, 索引:{test_indexes}")
print(f"追踪模块输出:{mod3_traced(test_x, test_indexes)}")
输出(值会变化):
编译的TracingModule类型:<class 'torch.jit.ScriptModule'>
输入x形状:torch.Size([5, 2]), 索引:tensor([0, 1, 3])
追踪模块输出:2.1580798625946045
解释:当e3nn.util.jit.compile
尝试追踪TracingModule
时,它看到_make_tracing_inputs
方法。
它不尝试从irreps_in
猜测输入(因为TracingModule
没有indexes
的irreps_in
),而是调用_make_tracing_inputs
获取有效示例输入,从而成功追踪。
示例4:compile_mode("unsupported")
如果您有一个自定义模块使用完全与TorchScript不兼容的高级Python特性(即使有e3nn.util.jit
的帮助),您可以将其标记为"unsupported"
。
这提供了更清晰的错误消息,如果有人尝试编译它。
@compile_mode('unsupported')
class ChildMod(torch.nn.Module):# 此模块故意不可脚本化def forward(self, x):# 想象这里有一些复杂的动态Python逻辑passclass Supermod(torch.nn.Module):def __init__(self) -> None:super().__init__()self.child = ChildMod()mod_unsupported = Supermod()try:script(mod_unsupported)
except NotImplementedError as e:print(f"\n捕获预期错误:{e}")
输出:
捕获预期错误:ChildMod不支持TorchScript编译
解释:e3nn.util.jit
正确识别"unsupported"
模式并引发有用的NotImplementedError
,而不是晦涩的TorchScript错误。
内部机制:e3nn.util.jit
如何工作
e3nn.util.jit
的魔法主要在其compile
函数中(位于e3nn/util/jit.py
)。此函数编排递归编译过程。
以下是e3nn.util.jit.compile
如何工作的简化序列:
- 获取
compile_mode
:当在模块上调用e3nn.util.jit.compile
时,它首先检查模块(及其类型)是否有@compile_mode
装饰器设置的_E3NN_COMPILE_MODE
属性。 - 递归步骤:如果
recurse=True
(这是默认且推荐的),e3nn.util.jit.compile
遍历所有直接子模块。对于每个子模块,它递归调用自身。这确保首先编译内部模块。 - 编译:一旦递归展开(意味着当前模块的所有子模块已经编译),
e3nn.util.jit.compile
继续编译当前模块本身:- 如果设置了
@compile_mode('script')
,它调用torch.jit.script
。 - 如果设置了
@compile_mode('trace')
,它首先生成示例输入(如果可用则使用_make_tracing_inputs
,或通过推断irreps_in
),然后调用torch.jit.trace_module
。 - 如果是
@compile_mode('unsupported')
,它引发NotImplementedError
。
- 如果设置了
- 替换:原始Python模块被其TorchScript编译版本替换,保留在父模块结构中。
这种系统方法允许e3nn
处理其等变神经网络层中存在的动态和静态行为的复杂混合,生成统一的优化TorchScript模型。
为自定义模块选择'script'
和'trace'
当您构建自己的自定义e3nn
模块并考虑如何用@compile_mode
注释它们时,这里有一个简单指南:
特性/行为 | @compile_mode('script') | @compile_mode('trace') |
---|---|---|
控制流 | 支持数据相关if /for 语句(例如,if x.mean() > 0: )。 | 不支持数据相关控制流。仅记录示例输入所采取的路径。 |
Python特性 | Python子集有限。继承、动态类创建有问题。 | 如果forward 是静态的,与一般Python兼容性更好。 |
e3nn 上下文 | 适用于基于实际特征值有条件逻辑的模块。例如:e3nn.nn.Gate 。 | 适用于一旦__init__ 中Irreps 已知forward 图就固定的模块,即使使用复杂Irreps 逻辑构建图。例如:e3nn.o3.TensorProduct ,e3nn.o3.Linear 。 |
通常,e3nn
的内置模块已经适当装饰。编写自己的模块时,如果有任何依赖于张量值的if
/for
语句,首先尝试@compile_mode('script')
。如果遇到错误,考虑是否可以重写forward
方法使其更静态,然后尝试@compile_mode('trace')
。如果e3nn.util.jit
本身失败,那么_make_tracing_inputs
是您追踪的工具。
结论
本文介绍了如何通过TorchScript JIT编译优化e3nn等变神经网络模型的部署性能。
e3nn模型包含静态张量操作和动态控制流,给标准JIT编译带来挑战。e3nn.util.jit提供了智能解决方案:
通过@compile_mode装饰器标记模块编译方式(script/trace/unsupported)
采用递归编译策略,自动混合脚本化和追踪模式
提供script(),trace(),compile()等接口简化编译过程
-
示例展示了如何成功编译包含数据相关控制流的e3nn模块,并验证编译前后模型输出一致。
-
最后演示了显式混合使用追踪和脚本化模式的方法。这些技术使e3nn模型能在生产环境中高效运行。
TorchScript JIT支持是e3nn
中的一个强大功能,允许获取训练好的等变神经网络并优化它们以进行高性能部署。通过智能结合PyTorch的scripting
和tracing
功能,e3nn.util.jit
无缝处理等变模型的独特复杂性。
我们已经学习了e3nn
如何使用@compile_mode
装饰器及其script
、trace
和compile
函数来管理此过程,包括在需要时如何提供自定义追踪输入。有了这些知识,可以确保e3nn
模型不仅智能且对称,而且快速并准备好用于实际应用。
这结束了我们通过e3nn
核心概念的教程之旅。希望你现在可以拥有探索、构建和创新等变神经网络的坚实基础啦
END ★,°:.☆( ̄▽ ̄).°★* 。