当前位置: 首页 > ops >正文

[e3nn] 模型部署 | TorchScript JIT | `@compile_mode`装饰器 | Cython

第6章:TorchScript JIT支持

欢迎来到e3nn的最终章~

在第5章:归一化中,我们学习了如何通过特征缩放使e3nn模型在训练过程中保持稳定。

现在,让我们讨论如何让训练好的模型变得快速准备好部署

假设您已经使用e3nn构建了一个出色的等变神经网络,可以高精度预测分子性质

您希望在生产环境中使用这个模型,可能是在移动设备云服务器上,或作为大型高性能模拟的一部分。在这些场景中,模型的速度和效率至关重要。

这就是TorchScript JIT支持的用武之地。PyTorch提供了一个名为TorchScript即时(JIT)编译的功能,允许将Python PyTorch模型转换为优化的序列化表示

这个编译版本运行更快,并且可以在不需要完整Python环境的情况下部署。

TorchScript为e3nn解决什么问题?

TorchScript的目标是优化和序列化PyTorch模型。对于大多数标准PyTorch模型来说,这很简单。

然而,e3nn模型提出了独特的挑战:

  1. 动态与静态操作e3nn模块通常混合两种操作:

    • 静态图操作:许多核心e3nn操作,如张量积中的操作,是计算密集的张量操作(如torch.einsum)。一旦Irreps已知,这些操作的确切序列就固定了。
    • 数据相关控制流e3nn模型的其他部分可能包含Python if语句或for循环,其执行路径取决于运行时张量的实际(例如,“如果这个范数超过某个阈值,执行X,否则执行Y”)。
  2. TorchScript的两种模式:PyTorch的JIT编译有两种主要模式,每种适用于不同场景:

    • 追踪(torch.jit.trace):记录模型在示例输入上运行时的操作。它创建一个静态计算图。这对性能很好,但不能处理数据相关控制流(例如,条件随输入值变化的if语句)。
    • 脚本化(torch.jit.script):直接分析Python源代码并将其编译为TorchScript。这可以处理数据相关控制流,但仅支持Python特性的子集(例如,它不总是能很好地处理复杂继承或动态类创建)。

e3nn的挑战在于其模块通常需要两者。例如,像e3nn.nn.Gate这样的模块可能包含数据相关的if语句(需要脚本化),但其内部可能使用e3nn.o3.TensorProduct(它依赖于基于Irrepse3nn内部代码生成,更适合追踪或专门处理)。直接对e3nn模型应用标准torch.jit.scripttorch.jit.trace通常会导致错误。

为了解决这个问题,e3nn提供了**e3nn.util.jit**,这是一组专门处理编译过程的工具。

自动找出编译e3nn模型不同部分的最佳方式,根据需要混合脚本化和追踪,并向用户隐藏这种复杂性。


🎢@compile_mode装饰器

@compile_mode是Python中一个用于控制代码编译行为的装饰器,通常与元编程或代码生成技术结合使用。它能够修改函数或类的编译方式,例如改变字节码生成规则或启用特定优化。

常见场景:

  1. 代码优化:通过装饰器标记需要特殊编译优化的函数,例如禁用断言内联特定操作。
  2. 语法扩展:配合自定义解析器实现非标准语法(如领域特定语言)。
  3. 调试工具:在开发阶段动态注入调试代码或性能分析逻辑。

code:

def compile_mode(optimize=False):def decorator(func):func.__compile_mode__ = {'optimize': optimize}return funcreturn decorator@compile_mode(optimize=True)
def heavy_calculation():# 该函数会被标记为需要优化编译return sum(x*x for x in range(10**6))

注意事项:

  • 该装饰器并非Python标准库内置功能,通常由第三方库(如Numba或PyPy)提供实现
  • 具体效果取决于底层编译器或解释器的支持程度
  • 过度使用可能导致代码可读性下降和调试困难

典型实现库:

  1. Numba:通过@jit装饰器实现即时编译优化
  2. Cython:使用@cython.compile执行C语言转换
  3. PyPy:利用RLPython实现的特殊编译模式

🎢Cython

Cython 是一种编程语言,通过将 Python 代码编译为 C 语言来提升性能。核心功能包括类型声明和直接调用 C 库

@cython.compile 是实验性装饰器,可自动编译函数为 C 扩展模块

以下代码展示如何用 @cython.compile 加速斐波那契数列计算:

import cython@cython.compile
def fib(n: cython.int) -> cython.int:a: cython.int = 0b: cython.int = 1for _ in range(n):a, b = b, a + breturn aprint(fib(10))  # 输出55

解析 :

  1. 类型声明:使用 cython.int 明确变量类型,避免 Python 动态类型开销。
  2. 循环优化for 循环会被编译为 C 级高效实现。
  3. 装饰器作用@cython.compile 自动生成 C 代码并编译为 .so.pyd 文件。

性能对比 :

  • 未优化 Python 实现可能慢 10-100 倍。Cython 通过静态类型和直接编译到 C 显著减少函数调用开销。

注意事项:

  • 需安装 Cython 包:pip install cython
  • 实验性功能可能不稳定,生产环境建议用 .pyx 文件结合 setup.py 编译。

e3nn.util.jit的关键概念

让我们看看使e3nn支持JIT编译的核心思想:

  1. @compile_mode装饰器

    • 这是e3nn模块(如e3nn.o3.Lineare3nn.nn.Gate)告诉e3nn.util.jit它们希望如何被编译的方式。
    • 这是一个Python装饰器,您可以在类定义上方添加:
      • @compile_mode('script'):表示此模块希望使用脚本化编译(适用于数据相关控制流)。
      • @compile_mode('trace'):表示此模块希望使用追踪编译(适用于静态操作)。
      • @compile_mode('unsupported'):表示此模块无法编译为TorchScript。
  2. 递归编译

    • e3nn.util.jit函数(scripttracecompile)递归工作。当您请求编译主模块时,e3nn.util.jit首先查看其子模块。
    • 它根据每个子模块的@compile_mode装饰器编译每个子模块。一旦子模块被编译(无论是脚本化还是追踪),它就被替换为其TorchScript版本。
    • 处理完所有子模块后,主模块本身被编译,平滑集成已编译的部分。这对于处理复杂模块层次结构和混合编译模式至关重要。

如何使用e3nn.util.jit

我们主要通过e3nn.util.jitscripttracecompile函数与其交互,这些函数作为PyTorch原生JIT编译函数的智能包装器

示例1:使用e3nn.util.jit.script脚本化模块

让我们定义一个简单模块,在其forward方法中包含数据相关控制流,并在内部使用e3nn.o3.Norme3nn.o3.Norm用于计算特征向量的范数,这是e3nn模型中的常见操作。

import torch
from e3nn.o3 import Norm, Irreps
from e3nn.util.jit import script, trace, compile_modeclass MyModule(torch.nn.Module):def __init__(self, irreps_in) -> None:super().__init__()# Norm是一个专门的TensorProduct,直接脚本化很复杂self.norm = Norm(irreps_in)def forward(self, x):norm = self.norm(x)# 这是数据相关控制流:它取决于'norm'的值if torch.any(norm > 7.):return normelse:return norm * 0.5# 定义输入irreps(例如,两个标量和一个向量)
irreps = Irreps("2x0e + 1x1o")
mod = MyModule(irreps)print(f"原始模块:{mod}")

现在,尝试使用标准torch.jit.script编译:

try:# 这可能会失败!mod_script_fail = torch.jit.script(mod)
except Exception as e:print(f"\n标准torch.jit.script失败:{e.__class__.__name__}: {e}")

输出(会变化但通常指示TorchScript限制):

标准torch.jit.script失败:TorchScriptError: 无法脚本化一个对象,该对象是定义了`__getstate__`或`__setstate__`的类型的子类,除非该对象是ScriptModule。

错误消息指向Norm(它内部继承自TensorProductCodeGenMixin),特别是直接脚本化时的__getstate____setstate__或其他不支持的Python特性问题。

这就是e3nn.util.jit.script来救援的时候

# 使用e3nn的智能脚本化函数
mod_script = script(mod)
print(f"\n使用e3nn.util.jit.script成功编译:{type(mod_script)}")# 验证输出相同
x = irreps.randn(2, -1) # 创建一些随机输入数据
assert torch.allclose(mod(x), mod_script(x))
print("原始模块和编译模块产生相同结果。")

输出:

使用e3nn.util.jit.script成功编译:<class 'torch.jit.ScriptModule'>
原始模块和编译模块产生相同结果。

解释e3nn.util.jit.script首先识别MyModuleself.norm是一个e3nn.o3.Norm模块。

Norm(如e3nn.o3.TensorProduct)内部设计为由追踪处理,因为其内部结构在基于Irreps初始化后是静态的。

  • 因此,e3nn.util.jit.script首先追踪self.norm,有效地将其转换为TorchScript图。

  • 然后,它继续脚本化MyModule本身,此时self.norm已经是torch.jit.script可以轻松集成的有效TorchScript子模块。

这种递归、混合模式的编译是e3nn.util.jit如此强大的原因。


示例2:显式混合追踪和脚本化

也可以使用@compile_mode装饰器显式告诉e3nn如何编译自定义模块

让我们创建一个显式标记为script模式的MyModule,并将其嵌入到我们将要traceAnotherModule中。

# 显式标记MyModule为脚本化
@compile_mode('script')
class MyModuleWithMode(torch.nn.Module):def __init__(self, irreps_in) -> None:super().__init__()self.norm = Norm(irreps_in)def forward(self, x):norm = self.norm(x)# 数据相关控制流for row in norm:if torch.any(row > 0.1):return rowreturn normclass AnotherModule(torch.nn.Module):def __init__(self, irreps_in) -> None:super().__init__()self.mymod = MyModuleWithMode(irreps_in) # 包含脚本化模块def forward(self, x):return self.mymod(x) + 3.irreps = Irreps("2x0e + 1x1o")
mod2 = AnotherModule(irreps)
print(f"原始AnotherModule:{mod2}")# 我们追踪AnotherModule,但其子模块MyModuleWithMode将被脚本化
example_inputs = (irreps.randn(3, -1),) # 追踪的示例输入
mod2_traced = trace(mod2, example_inputs)
print(f"\n追踪的AnotherModule:{type(mod2_traced)}")# 用不同输入测试追踪模块
print("\n全零输出:")
print(mod2_traced(torch.zeros(2, irreps.dim)))
print("\n随机输入输出:")
print(mod2_traced(irreps.randn(3, -1)))

输出(值会因随机输入而变化):

追踪的AnotherModule:<class 'torch.jit.ScriptModule'>全零输出:
tensor([3., 3., 3., 3., 3.])随机输入输出:
tensor([3.9114, 3.9114, 3.9114, 3.9114, 3.9114])

注意,即使AnotherModule本身被追踪(trace返回一个嵌入图的ScriptModule),MyModuleWithMode中的数据相关控制流(其for循环和if语句)被保留。通过检查模块类型可以确认这一点:

print(f"mod2_traced类型:{type(mod2_traced)}")
print(f"mod2_traced.mymod类型:{type(mod2_traced.mymod)}")

输出:

mod2_traced类型:<class 'torch.jit.ScriptModule'>
mod2_traced.mymod类型:<class 'torch.jit.ScriptModule'>

在这种情况下,两者都变成了ScriptModule

关键是e3nn.util.jit.trace首先递归使用script编译MyModuleWithMode(因为其装饰器),然后编译AnotherModule

如果完全捕获包括子模块在内的所有图操作,torch.jit.trace_module可以返回ScriptModule


示例3:使用_make_tracing_inputs自定义追踪输入

有时,e3nn无法自动猜测如何生成追踪的示例输入,特别是如果您的模块接受标准特征张量之外的复杂参数(例如,整数索引、张量列表)。

在这种情况下,您可以在模块中定义_make_tracing_inputs方法。

@compile_mode('trace') # 我们希望追踪此模块
class TracingModule(torch.nn.Module):def forward(self, x: torch.Tensor, indexes: torch.LongTensor):return x[indexes].sum()# 此方法告诉e3nn.util.jit如何生成追踪的示例输入def _make_tracing_inputs(self, n: int):import random# 'n'是生成多少示例输入的建议# 我们需要返回字典列表,每个字典将方法名(如'forward')映射到示例参数元组return [{'forward': (torch.randn(5, random.randint(1, 3)), # 随机张量'x'torch.arange(random.randint(1, 3))    # 随机LongTensor'indexes')}for _ in range(n)]from e3nn.util.jit import compile # 直接使用`compile`以尊重装饰器mod3 = TracingModule()
mod3_traced = compile(mod3) # 这里不需要显式输入,_make_tracing_inputs提供它们
print(f"编译的TracingModule类型:{type(mod3_traced)}")# 测试追踪模块
test_x = torch.randn(5, 2)
test_indexes = torch.tensor([0, 1, 3]) # 注意:索引必须对test_x有效
print(f"输入x形状:{test_x.shape}, 索引:{test_indexes}")
print(f"追踪模块输出:{mod3_traced(test_x, test_indexes)}")

输出(值会变化):

编译的TracingModule类型:<class 'torch.jit.ScriptModule'>
输入x形状:torch.Size([5, 2]), 索引:tensor([0, 1, 3])
追踪模块输出:2.1580798625946045

解释:当e3nn.util.jit.compile尝试追踪TracingModule时,它看到_make_tracing_inputs方法。

它不尝试从irreps_in猜测输入(因为TracingModule没有indexesirreps_in),而是调用_make_tracing_inputs获取有效示例输入,从而成功追踪。

示例4:compile_mode("unsupported")

如果您有一个自定义模块使用完全与TorchScript不兼容的高级Python特性(即使有e3nn.util.jit的帮助),您可以将其标记为"unsupported"

这提供了更清晰的错误消息,如果有人尝试编译它。

@compile_mode('unsupported')
class ChildMod(torch.nn.Module):# 此模块故意不可脚本化def forward(self, x):# 想象这里有一些复杂的动态Python逻辑passclass Supermod(torch.nn.Module):def __init__(self) -> None:super().__init__()self.child = ChildMod()mod_unsupported = Supermod()try:script(mod_unsupported)
except NotImplementedError as e:print(f"\n捕获预期错误:{e}")

输出:

捕获预期错误:ChildMod不支持TorchScript编译

解释e3nn.util.jit正确识别"unsupported"模式并引发有用的NotImplementedError,而不是晦涩的TorchScript错误。

内部机制:e3nn.util.jit如何工作

e3nn.util.jit的魔法主要在其compile函数中(位于e3nn/util/jit.py)。此函数编排递归编译过程。

以下是e3nn.util.jit.compile如何工作的简化序列:

在这里插入图片描述

  1. 获取compile_mode:当在模块上调用e3nn.util.jit.compile时,它首先检查模块(及其类型)是否有@compile_mode装饰器设置的_E3NN_COMPILE_MODE属性。
  2. 递归步骤:如果recurse=True(这是默认且推荐的),e3nn.util.jit.compile遍历所有直接子模块。对于每个子模块,它递归调用自身。这确保首先编译内部模块。
  3. 编译:一旦递归展开(意味着当前模块的所有子模块已经编译),e3nn.util.jit.compile继续编译当前模块本身:
    • 如果设置了@compile_mode('script'),它调用torch.jit.script
    • 如果设置了@compile_mode('trace'),它首先生成示例输入(如果可用则使用_make_tracing_inputs,或通过推断irreps_in),然后调用torch.jit.trace_module
    • 如果是@compile_mode('unsupported'),它引发NotImplementedError
  4. 替换:原始Python模块被其TorchScript编译版本替换,保留在父模块结构中。

这种系统方法允许e3nn处理其等变神经网络层中存在的动态和静态行为的复杂混合,生成统一的优化TorchScript模型。

为自定义模块选择'script''trace'

当您构建自己的自定义e3nn模块并考虑如何用@compile_mode注释它们时,这里有一个简单指南:

特性/行为@compile_mode('script')@compile_mode('trace')
控制流支持数据相关if/for语句(例如,if x.mean() > 0:)。不支持数据相关控制流。仅记录示例输入所采取的路径。
Python特性Python子集有限。继承、动态类创建有问题。如果forward是静态的,与一般Python兼容性更好。
e3nn上下文适用于基于实际特征值有条件逻辑的模块。例如:e3nn.nn.Gate适用于一旦__init__Irreps已知forward图就固定的模块,即使使用复杂Irreps逻辑构建图。例如:e3nn.o3.TensorProducte3nn.o3.Linear

通常,e3nn的内置模块已经适当装饰。编写自己的模块时,如果有任何依赖于张量值的if/for语句,首先尝试@compile_mode('script')。如果遇到错误,考虑是否可以重写forward方法使其更静态,然后尝试@compile_mode('trace')。如果e3nn.util.jit本身失败,那么_make_tracing_inputs是您追踪的工具。

结论

本文介绍了如何通过TorchScript JIT编译优化e3nn等变神经网络模型的部署性能。

e3nn模型包含静态张量操作和动态控制流,给标准JIT编译带来挑战。e3nn.util.jit提供了智能解决方案:

通过@compile_mode装饰器标记模块编译方式(script/trace/unsupported)
采用递归编译策略,自动混合脚本化和追踪模式
提供script(),trace(),compile()等接口简化编译过程

  • 示例展示了如何成功编译包含数据相关控制流的e3nn模块,并验证编译前后模型输出一致。

  • 最后演示了显式混合使用追踪和脚本化模式的方法。这些技术使e3nn模型能在生产环境中高效运行。

TorchScript JIT支持e3nn中的一个强大功能,允许获取训练好的等变神经网络并优化它们以进行高性能部署。通过智能结合PyTorch的scriptingtracing功能,e3nn.util.jit无缝处理等变模型的独特复杂性

我们已经学习了e3nn如何使用@compile_mode装饰器及其scripttracecompile函数来管理此过程,包括在需要时如何提供自定义追踪输入。有了这些知识,可以确保e3nn模型不仅智能且对称,而且快速并准备好用于实际应用

这结束了我们通过e3nn核心概念的教程之旅。希望你现在可以拥有探索、构建和创新等变神经网络的坚实基础啦

END ★,°:.☆( ̄▽ ̄).°★* 。

http://www.xdnf.cn/news/18523.html

相关文章:

  • 老年常见疾病及健康管理建议
  • 精斗云智能开单解决方案:高效移动办公新体验
  • Qt/C++开发监控GB28181系统/录像文件回放/自动播放下一个录像文件/倍速回放/录像文件下载
  • openharmony之一多开发:产品形态配置讲解
  • 使用自制的NTC测量模块测试Plecs的热仿真效果
  • 分布式蜜罐系统的部署安装
  • 微服务统一入口——Gateway
  • Redis 从入门到精通:原理、实战与性能优化全解析
  • Flutter BLoC 全面入门与实战(含代码示例)
  • 云计算-K8s 运维:Python SDK 操作 Job/Deployment/Pod+RBAC 权限配置及自定义 Pod 调度器实战
  • 概率论基础教程第六章 随机变量的联合分布(一)
  • FastAPI + SQLAlchemy 数据库对象转字典
  • 解决coze api使用coze.workflows.runs.create运行workflow返回400,但text为空
  • SEO优化工具学习——Ahrefs进行关键词调研(包含实战)
  • 市政道路井盖缺失识别误报率↓82%!陌讯多模态融合算法实战优化与边缘部署
  • ChipCamp探索系列 -- 4. Intel CPU的十八代微架构
  • 【React Native】自定义轮盘(大转盘)组件Wheel
  • 【KO】前端面试题四
  • 今日科技热点 | 量子计算突破、AI芯片与5G加速行业变革
  • PLECS 中使用 C-Script 来模拟 NTC 热敏电阻(如 NTC3950B)
  • 【K8s】整体认识K8s之Docker篇
  • 百度面试题:赛马问题
  • 嵌入式LINUX-------------数据库
  • 循环中的阻塞风险与异步线程解法
  • 搜索体验优化:ABP vNext 的查询改写(Query Rewrite)与同义词治理
  • 前端安全之XSS和CSRF
  • 鸿蒙异步处理从入门到实战:Promise、async/await、并发池、超时重试全套攻略
  • 互联网大厂Java面试实战:核心技术栈与场景化提问解析(含Spring Boot、微服务、测试框架等)
  • 量子计算驱动的Python医疗诊断编程前沿展望(中)
  • RabbitMQ面试精讲 Day 28:Docker与Kubernetes部署实践