PyTorch API 5
文章目录
- torch.compiler
- 延伸阅读
- torch.fft
- 快速傅里叶变换
- 辅助函数
- torch.func
- 什么是可组合的函数变换?
- 为什么需要可组合的函数变换?
- 延伸阅读
- torch.futures
- torch.fx
- 概述
- 编写转换函数
- 图结构快速入门
- 图操作
- 直接操作计算图
- 使用 replace_pattern() 进行子图重写
- 图操作示例
- 代理/回溯机制
- 解释器模式
- 解释器模式示例
- 调试
- 简介
- 变换编写中的常见陷阱
- 检查模块的正确性
- 调试生成的代码
- 使用 `pdb`
- 打印生成的代码
- 使用 `GraphModule` 中的 `to_folder` 函数
- 调试转换过程
- 可用的调试器
- 符号追踪的局限性
- 动态控制流
- 静态控制流
- 非`torch`函数
- 使用 `Tracer` 类自定义追踪功能
- 叶子模块
- 杂项说明
- API 参考
- torch.fx.experimental
- torch.fx.experimental.symbolic_shapes
- torch.fx.experimental.proxy_tensor
- torch.hub
- 发布模型
- 如何实现入口点?
- 重要通知
- 从Hub加载模型
- 运行加载的模型:
- 下载的模型保存在哪里?
- 缓存逻辑
- 已知限制:
- TorchScript
- 创建 TorchScript 代码
- 混合使用追踪与脚本化
- TorchScript 语言
- 内置函数与模块
- PyTorch 函数与模块
- Python 函数与模块
- Python 语言参考对比
- 调试
- 禁用 JIT 进行调试
- 代码检查
- 解读图结构
- 追踪器
- 追踪边界情况
- 自动追踪检查
- 追踪器警告
- 常见问题解答
- 已知问题
- 附录
- 迁移至 PyTorch 1.2 递归脚本化 API
- 模块
- 函数
- TorchScript 类
- 属性
- 常量
- 变量
- 融合后端
- 参考资料
- torch.linalg
- 矩阵属性
- 矩阵分解
- 求解器
- 逆矩阵
- 矩阵函数
- 矩阵运算
- 张量运算
- 杂项函数
- 实验性函数
- torch.monitor
- API 参考
- torch.signal 模块
- torch.signal.windows 窗口函数
- torch.special
- 函数
- torch.overrides
- 函数
- torch.package
- 教程
- 打包你的第一个模型
- 如何实现...
- 查看包内包含哪些内容?
- 将包视为ZIP归档文件处理
- 使用 `file_structure()` API
- 查看某个模块为何被列为依赖项?
- 如何在打包时包含任意资源并后续访问?
- 自定义类的打包方式
- 如何在源码中检测当前是否运行在包环境中?
- 如何将代码补丁打入包中?
- 如何从打包代码中访问包内容?
- 区分打包代码与非打包代码
- 如何重新导出已导入的对象?
- 如何打包 TorchScript 模块?
- 说明
- `torch.package` 格式概述
- 框架文件
- 用户文件
- `torch.package` 如何查找代码依赖项
- 分析对象的依赖关系
- 分析模块依赖关系
- 依赖管理
- `intern`
- `extern`
- `mock`
- 代码重构
- 模式
- `torch.package` 的注意事项
- 避免在模块中使用全局状态
- 类型在包与加载环境之间不共享
- `torch.package` 如何实现包之间的隔离
- 名称修饰(Mangling)
- API 参考
- torch.profiler
- 概述
- API 参考
- Intel 插桩与追踪技术 API
- torch.nn.init
- torch.nn.attention
- 工具集
- 子模块
torch.compiler
torch.compiler
是一个命名空间,通过它向用户开放了一些内部编译器方法。该命名空间中的主要功能和特性是 torch.compile
。
torch.compile
是 PyTorch 2.x 引入的一个函数,旨在解决 PyTorch 中精确图捕获的问题,最终帮助软件工程师加速运行他们的 PyTorch 程序。torch.compile
使用 Python 编写,标志着 PyTorch 从 C++ 向 Python 的过渡。
torch.compile
利用了以下底层技术:
- TorchDynamo (torch._dynamo) 是一个内部 API,它使用 CPython 的 Frame Evaluation API 功能来安全捕获 PyTorch 计算图。通过
torch.compiler
命名空间向 PyTorch 用户开放可用方法。 - TorchInductor 是
torch.compile
默认的深度学习编译器,为多种加速器和后端生成快速代码。需要通过后端编译器才能实现torch.compile
的加速效果。对于 NVIDIA、AMD 和 Intel GPU,它使用 OpenAI Triton 作为关键构建块。 - AOT Autograd 不仅能捕获用户级代码,还能捕获反向传播,实现"提前"捕获反向传递。这使得 TorchInductor 能够同时加速前向和反向传递。
注意:在本文档中,术语 torch.compile
、TorchDynamo 和 torch.compiler
有时会互换使用。
如上所述,要通过 TorchDynamo 运行更快的工作流,torch.compile
需要一个后端将捕获的计算图转换为快速机器码。不同的后端会带来不同的优化效果。默认后端是 TorchInductor(也称为 inductor)。TorchDynamo 还支持由合作伙伴开发的一系列后端,可以通过运行 torch.compiler.list_backends()
查看,每个后端都有其可选依赖项。
一些最常用的后端包括:
训练和推理后端
后端 | 描述 |
---|---|
torch.compile(m, backend="inductor") | 使用 TorchInductor 后端。了解更多 |
torch.compile(m, backend="cudagraphs") | 使用 AOT Autograd 的 CUDA 图。了解更多 |
torch.compile(m, backend="ipex") | 在 CPU 上使用 IPEX。了解更多 |
torch.compile(m, backend="onnxrt") | 使用 ONNX Runtime 在 CPU/GPU 上进行训练。了解更多 |
仅推理后端
后端 | 描述 |
---|---|
torch.compile(m, backend="tensorrt") | 使用 Torch-TensorRT 进行推理优化。需要在调用脚本中 import torch_tensorrt 来注册后端。了解更多 |
torch.compile(m, backend="ipex") | 在 CPU 上使用 IPEX 进行推理。了解更多 |
torch.compile(m, backend="tvm") | 使用 Apache TVM 进行推理优化。了解更多 |
torch.compile(m, backend="openvino") | 使用 OpenVINO 进行推理优化。了解更多 |
延伸阅读
PyTorch 用户入门指南
- 快速入门
- torch.compiler API 参考
- torch.compiler.config 配置
- TorchDynamo 细粒度追踪 API
- AOTInductor: Torch.Export 模型的预编译方案
- TorchInductor GPU 性能分析
- torch.compile 性能剖析指南
- 常见问题解答
- torch.compile 故障排查
- PyTorch 2.0 性能看板
PyTorch 开发者深度解析
- Dynamo 架构概览
- Dynamo 技术深潜
- 动态形状支持
- PyTorch 2.0 NNModule 支持
- 后端开发最佳实践
- CUDA 图树优化
- 伪张量机制
PyTorch 后端供应商指南
- 自定义后端开发
- ATen IR 图转换开发
- 中间表示层详解
torch.fft
离散傅里叶变换及相关函数。
快速傅里叶变换
fft | 计算input 的一维离散傅里叶变换 |
---|---|
ifft | 计算input 的一维离散傅里叶逆变换 |
fft2 | 计算input 的二维离散傅里叶变换 |
ifft2 | 计算input 的二维离散傅里叶逆变换 |
fftn | 计算input 的N维离散傅里叶变换 |
ifftn | 计算input 的N维离散傅里叶逆变换 |
rfft | 计算实数input 的一维傅里叶变换 |
irfft | 计算rfft() 的逆变换 |
rfft2 | 计算实数input 的二维离散傅里叶变换 |
irfft2 | 计算rfft2() 的逆变换 |
rfftn | 计算实数input 的N维离散傅里叶变换 |
irfftn | 计算rfftn() 的逆变换 |
hfft | 计算Hermitian对称input 信号的一维离散傅里叶变换 |
ihfft | 计算hfft() 的逆变换 |
hfft2 | 计算Hermitian对称input 信号的二维离散傅里叶变换 |
ihfft2 | 计算实数input 的二维离散傅里叶逆变换 |
hfftn | 计算Hermitian对称input 信号的N维离散傅里叶变换 |
ihfftn | 计算实数input 的N维离散傅里叶逆变换 |
辅助函数
fftfreq | 计算大小为 n 的信号的离散傅里叶变换采样频率。 |
---|---|
rfftfreq | 计算大小为 n 的信号在使用 rfft() 时的采样频率。 |
fftshift | 对由 fftn() 提供的 n 维 FFT 数据进行重新排序,使负频率项优先。 |
ifftshift | fftshift() 的逆操作。 |
torch.func
torch.func(前身为"functorch")是为PyTorch提供的JAX风格可组合函数变换工具。
注意:该库目前处于测试阶段。
这意味着这些功能基本可用(除非另有说明),且我们(PyTorch团队)将持续推进该库的发展。但API可能会根据用户反馈进行调整,且尚未完全覆盖所有PyTorch操作。
如果您对API有改进建议,或希望支持特定使用场景,请提交GitHub issue或直接联系我们。我们非常期待了解您如何使用这个库。
什么是可组合的函数变换?
- 函数变换是一种高阶函数,它接受一个数值函数作为输入,并返回一个新函数来计算不同的量。
torch.func
提供了自动微分变换(例如grad(f)
返回计算f
梯度的函数)、向量化/批处理变换(例如vmap(f)
返回对输入批次执行f
的函数)等多种变换。- 这些函数变换可以任意组合使用。例如,组合
vmap(grad(f))
可以计算单样本梯度(per-sample-gradients),这是当前标准 PyTorch 无法高效计算的量。
为什么需要可组合的函数变换?
目前在 PyTorch 中实现以下用例较为棘手:
- 计算逐样本梯度(或其他逐样本量)
- 在单台机器上运行模型集成
- 在 MAML 内循环中高效批处理任务
- 高效计算雅可比矩阵和海森矩阵
- 高效计算批量雅可比矩阵和海森矩阵
通过组合使用 vmap()
、grad()
和 vjp()
变换,我们无需为每个用例单独设计子系统即可实现上述功能。这种可组合函数变换的理念源自 JAX 框架。
延伸阅读
- torch.func 快速指南
- 什么是 torch.func?
- 为什么需要可组合函数变换?
- 有哪些变换方法?
- torch.func API 参考
- 函数变换
- torch.nn.Module 工具集
- 调试工具
- 使用限制
- 通用限制
- torch.autograd API
- vmap 限制
- 随机性控制
- 从 functorch 迁移到 torch.func
- 函数变换
- 神经网络模块工具
- functorch.compile
torch.futures
该包提供了一种 Future
类型,用于封装异步执行过程,并提供一组实用函数来简化对 Future
对象的操作。目前,Future
类型主要被 分布式RPC框架 使用。
class torch.futures.Future(*, devices=None)
Wrapper around a torch._C.Future
which encapsulates an asynchronous
execution of a callable, e.g. rpc_async()
. It also exposes a set of APIs to add callback functions and set results.
Warning: GPU support is a beta feature, subject to changes.
add_done_callback(callback)
将给定的回调函数附加到此Future
上,该回调函数将在Future
完成时运行。可以向同一个Future
添加多个回调,但无法保证它们的执行顺序。回调函数必须接受一个参数,即对此Future
的引用。回调函数可以使用value()
方法获取值。请注意,如果此Future
已经完成,给定的回调将立即内联执行。
我们建议使用then()
方法,因为它提供了一种在回调完成后进行同步的方式。如果回调不返回任何内容,add_done_callback
可能更高效。但then()
和add_done_callback
在底层使用相同的回调注册API。
对于GPU张量,此方法的行为与then()
相同。
参数
callback (
Future)
– 一个可调用对象,接受一个参数,即对此Future
的引用。
注意:请注意,如果回调函数抛出异常,无论是由于原始future以异常完成并调用fut.wait()
,还是由于回调中的其他代码,都必须仔细处理错误。例如,如果此回调随后完成了其他future,这些future不会被标记为以错误完成,用户需要独立处理这些future的完成/等待。
示例:
>>> def callback(fut):
... print("This will run after the future has finished.")
... print(fut.wait())
>>> fut = torch.futures.Future()
>>> fut.add_done_callback(callback)
>>> fut.set_result(5)
This will run after the future has finished.
5
done()
如果该Future
已完成则返回True
。当Future
包含结果或异常时即视为完成。
如果值包含位于GPU上的张量,即使填充这些张量的异步内核尚未在设备上完成运行,Future.done()
仍会返回True
,因为在此阶段结果已可被使用(前提是执行适当的同步操作,参见wait()
)。
返回类型:bool
set_exception(result)
为这个 Future
设置一个异常,这将标记该 Future
以错误状态完成,并触发所有已附加的回调。请注意,当对此 Future
调用 wait()/value() 时,此处设置的异常
将被内联抛出。
参数
result ([BaseException](https://docs.python.org/3/library/exceptions.html#BaseException "(in Python v3.13)"))
– 该Future
的异常对象。
示例
>>> fut = torch.futures.Future()
>>> fut.set_exception(ValueError("foo"))
>>> fut.wait()
Traceback (most recent call last):
...
ValueError: foo
set_result(result)
为这个Future
设置结果,这将标记该Future
为已完成状态并触发所有关联的回调。需要注意的是,一个Future
不能被标记为已完成两次。
如果结果包含位于GPU上的张量,即使填充这些张量的异步内核尚未在设备上完成运行,只要调用此方法时这些内核所入队的流被设置为当前流,仍可调用此方法。简而言之,在启动这些内核后立即调用此方法是安全的,无需额外同步,前提是期间不切换流。此方法会在所有相关当前流上记录事件,并利用它们确保此Future
的所有消费者都能得到正确调度。
参数
result ( object )
- 该Future
的结果对象。
示例:
>>> import threading
>>> import time
>>> def slow_set_future(fut, value):
... time.sleep(0.5)
... fut.set_result(value)
>>> fut = torch.futures.Future()
>>> t = threading.Thread(
... target=slow_set_future,
... args=(fut, torch.ones(2) * 3)
... )
>>> t.start()
>>> print(fut.wait())
tensor([3., 3.])
>>> t.join()
then(callback)
将给定的回调函数附加到此Future
上,该回调函数将在Future
完成时运行。可以向同一个Future
添加多个回调,但无法保证它们的执行顺序(如需确保特定顺序,请考虑链式调用:fut.then(cb1).then(cb2)
)。回调函数必须接受一个参数,即对此Future
的引用。回调函数可通过value()
方法获取值。请注意,如果此Future
已完成,给定的回调将立即内联执行。
如果Future
的值包含位于GPU上的张量,回调可能在填充这些张量的异步内核尚未在设备上完成执行时就被调用。不过,回调将通过设置为当前的一些专用流(从全局池中获取)被调用,这些流将与那些内核同步。因此,回调对这些张量执行的任何操作都将在内核完成后调度到设备上。换句话说,只要回调不切换流,它就可以安全地操作结果而无需额外同步。这与wait()
的非阻塞行为类似。
类似地,如果回调返回的值包含位于GPU上的张量,即使生成这些张量的内核仍在设备上运行,回调也可以这样做,前提是回调在执行期间没有切换流。如果想要切换流,必须注意与原始流重新同步,即回调被调用时当前的流。
参数
callback (
Callable)
– 一个以该Future
为唯一参数的可调用对象。
返回
一个新的Future
对象,它持有callback
的返回值,并将在给定callback
完成时标记为已完成。
返回类型
Future[S]
注意:请注意,如果回调函数抛出异常,无论是通过原始future以异常完成并调用fut.wait()
,还是通过回调中的其他代码,then
返回的future将适当地标记为遇到错误。但是,如果此回调随后完成其他future,这些future不会标记为以错误完成,用户需负责独立处理这些future的完成/等待。
示例:
>>> def callback(fut):
... print(f"RPC return value is {fut.wait()}.")
>>> fut = torch.futures.Future()
>>> # The inserted callback will print the return value when
>>> # receiving the response from "worker1"
>>> cb_fut = fut.then(callback)
>>> chain_cb_fut = cb_fut.then(
... lambda x : print(f"Chained cb done. {x.wait()}")
... )
>>> fut.set_result(5)
RPC return value is 5、Chained cb done. None
value()
获取已完成的Future对象的值。
此方法仅应在调用wait()
完成后,或在传递给then()
的回调函数内部使用。其他情况下,该Future
可能尚未持有值,调用value()
可能会失败。
如果值包含位于GPU上的张量,此方法将不会执行任何额外的同步操作。此类同步应事先通过调用wait()
单独完成(回调函数内部除外,因为then()
已自动处理此情况)。
返回值
该Future
持有的值。如果创建该值的函数(回调或RPC)抛出错误,此value()
方法同样会抛出错误。
返回类型:T
wait()
等待直到该 Future
的值准备就绪。
如果值包含位于 GPU 上的张量,则会与设备上异步填充这些张量的内核执行额外的同步操作。此类同步是非阻塞的,这意味着 wait()
会在当前流中插入必要的指令,以确保后续在这些流上排队的操作能正确安排在异步内核之后执行。但一旦完成指令插入,即使这些内核仍在运行,wait()
也会立即返回。只要不切换流,在访问和使用这些值时无需进一步同步。
返回值:此 Future
持有的值。如果创建该值的函数(回调或 RPC)抛出错误,此 wait
方法同样会抛出错误。
返回类型:T
torch.futures.collect_all(futures)
将提供的 Future
对象收集到一个统一的组合 Future
中,该组合 Future 会在所有子 Future 完成时完成。
参数
futures (list)
– 一个包含Future
对象的列表。
返回
返回一个 Future
对象,该对象关联到传入的 Future 列表。
返回类型
Future[list [torch.jit.Future]]
示例
>>> fut0 = torch.futures.Future()
>>> fut1 = torch.futures.Future()
>>> fut = torch.futures.collect_all([fut0, fut1])
>>> fut0.set_result(0)
>>> fut1.set_result(1)
>>> fut_list = fut.wait()
>>> print(f"fut0 result = {fut_list[0].wait()}")
fut0 result = 0
>>> print(f"fut1 result = {fut_list[1].wait()}")
fut1 result = 1
torch.futures.wait_all(futures)
等待所有提供的 futures 完成,并返回已完成值的列表。如果任一 future 遇到错误,该方法将提前退出并报告错误,而不会等待其他 futures 完成。
参数
futures (list)
– 一个Future
对象列表。
返回值:已完成 Future
结果的列表。如果对任何 Future
调用 wait
时抛出错误,该方法也会抛出错误。
返回类型:list
torch.fx
概述
FX 是一个供开发者使用的工具包,用于转换 nn.Module
实例。FX 包含三个核心组件:符号追踪器、中间表示和 Python 代码生成。以下是这些组件的实际应用演示:
import torch# Simple module for demonstration
class MyModule(torch.nn.Module):def __init__(self) -None:super().__init__()self.param = torch.nn.Parameter(torch.rand(3, 4))self.linear = torch.nn.Linear(4, 5)def forward(self, x):return self.linear(x + self.param).clamp(min=0.0, max=1.0)module = MyModule()from torch.fx import symbolic_trace# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():%x : [num_users=1] = placeholder[target=x]%param : [num_users=1] = get_attr[target=param]%add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})%linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})%clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})return clamp
"""# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):param = self.paramadd = x + param; x = param = Nonelinear = self.linear(add); add = Noneclamp = linear.clamp(min = 0.0, max = 1.0); linear = Nonereturn clamp
"""
符号追踪器(symbolic tracer)对Python代码执行"符号执行"。它通过代码传递称为Proxy的虚拟值,并记录对这些Proxy的操作。有关符号追踪的更多信息,请参阅symbolic_trace()
和Tracer
文档。
中间表示(intermediate representation)是符号追踪过程中记录操作的容器。它由一组节点组成,这些节点表示函数输入、调用点(指向函数、方法或torch.nn.Module
实例)以及返回值。有关IR的更多信息,请参阅Graph
文档。IR是应用转换的基础格式。
Python代码生成功能使FX成为Python到Python(或Module到Module)的转换工具包。对于每个Graph IR,我们都可以生成符合Graph语义的有效Python代码。这个功能被封装在GraphModule
中,它是一个torch.nn.Module
实例,包含一个Graph
以及从Graph生成的forward
方法。
这些组件(符号追踪→中间表示→转换→Python代码生成)共同构成了FX的Python到Python转换流程。此外,这些组件也可以单独使用。例如,符号追踪可以单独用于捕获代码形式进行分析(而非转换)目的。代码生成可以用于通过编程方式生成模型,例如从配置文件生成。FX有许多用途!
在示例库中可以找到几个转换示例。
编写转换函数
什么是FX转换?本质上,它是一个形如下列的函数。
import torch
import torch.fxdef transform(m: nn.Module, tracer_class : type = torch.fx.Tracer) -torch.nn.Module:# Step 1: Acquire a Graph representing the code in `m`# NOTE: torch.fx.symbolic_trace is a wrapper around a call to # fx.Tracer.trace and constructing a GraphModule. We'll# split that out in our transform to allow the caller to # customize tracing behavior.graph : torch.fx.Graph = tracer_class().trace(m)# Step 2: Modify this Graph or create a new onegraph = ...# Step 3: Construct a Module to returnreturn torch.fx.GraphModule(m, graph)
您的转换器将接收一个 torch.nn.Module
,从中获取 Graph
,进行一些修改后返回一个新的 torch.nn.Module
。您应该将 FX 转换器返回的 torch.nn.Module
视为与常规 torch.nn.Module
完全相同——可以将其传递给另一个 FX 转换器、传递给 TorchScript 或直接运行它。确保 FX 转换器的输入和输出均为 torch.nn.Module
将有助于实现组合性。
注意:也可以直接修改现有的 GraphModule
而不创建新实例,例如:
import torch
import torch.fxdef transform(m : nn.Module) -nn.Module:gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)# Modify gm.graph# <...># Recompile the forward() method of `gm` from its Graphgm.recompile()return gm
请注意,你必须调用 GraphModule.recompile()
方法,使生成的 forward()
方法与修改后的 Graph
保持同步。
假设你已经传入了一个经过追踪转换为 Graph
的 torch.nn.Module
,现在主要有两种方法来构建新的 Graph
。
图结构快速入门
关于图的语义完整说明可以参考 Graph
文档,这里我们主要介绍基础概念。Graph
是一种数据结构,用于表示 GraphModule
上的方法。其核心需要描述以下信息:
- 方法的输入参数是什么?
- 方法内部运行了哪些操作?
- 方法的输出(即返回值)是什么?
这三个概念都通过 Node
实例来表示。下面通过一个简单示例来说明:
import torch
import torch.fxclass MyModule(torch.nn.Module):def __init__(self):super().__init__()self.param = torch.nn.Parameter(torch.rand(3, 4))self.linear = torch.nn.Linear(4, 5)def forward(self, x):return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)m = MyModule()
gm = torch.fx.symbolic_trace(m)gm.graph.print_tabular()
这里我们定义一个演示用的模块 MyModule
,实例化后进行符号追踪,然后调用 Graph.print_tabular()
方法打印该 Graph
的节点表格:
操作码 | 名称 | 目标 | 参数 | 关键字参数 |
---|---|---|---|---|
placeholder | x | x | () | {} |
get_attr | linear_weight | linear.weight | () | {} |
call_function | add_1 | <built-in function add | (x, linear_weight) | {} |
call_module | linear_1 | linear | (add_1,) | {} |
call_method | relu_1 | relu | (linear_1,) | {} |
call_function | sum_1 | <built-in method sum … | (relu_1,) | {‘dim’: -1} |
call_function | topk_1 | <built-in method topk … | (sum_1, 3) | {} |
output | output | output | (topk_1,) | {} |
通过这些信息,我们可以回答之前提出的问题:
- 方法的输入是什么?
在FX中,方法输入通过特殊的placeholder
节点指定。本例中有一个目标为x
的placeholder
节点,表示存在一个名为x的(非self)参数。 - 方法内部有哪些操作?
get_attr
、call_function
、call_module
和call_method
节点表示方法中的操作。这些节点的完整语义说明可参考Node
文档。 - 方法的返回值是什么?
在Graph
中,返回值由特殊的output
节点指定。
现在我们已经了解FX中代码表示的基本原理,接下来可以探索如何编辑 Graph
。
图操作
直接操作计算图
构建新Graph
的一种方法是直接操作原有计算图。为此,我们可以简单地获取通过符号追踪得到的Graph
并进行修改。例如,假设我们需要将所有torch.add()
调用替换为torch.mul()
调用。
import torch
import torch.fx# Sample module
class M(torch.nn.Module):def forward(self, x, y):return torch.add(x, y)def transform(m: torch.nn.Module, tracer_class : type = fx.Tracer) -torch.nn.Module:graph : fx.Graph = tracer_class().trace(m)# FX represents its Graph as an ordered list of # nodes, so we can iterate through them.for node in graph.nodes:# Checks if we're calling a function (i.e:# torch.add)if node.op == 'call_function':# The target attribute is the function# that call_function calls.if node.target == torch.add:node.target = torch.mulgraph.lint() # Does some checks to make sure the # Graph is well-formed.return fx.GraphModule(m, graph)
我们还可以进行更复杂的 Graph
重写操作,例如删除或追加节点。为了辅助这些转换,FX 提供了一些用于操作计算图的实用函数,这些函数可以在 Graph
文档中找到。
下面展示了一个使用这些 API 追加 torch.relu()
调用的示例。
# Specifies the insertion point. Any nodes added to the # Graph within this scope will be inserted after `node` with traced.graph.inserting_after(node):# Insert a new `call_function` node calling `torch.relu`new_node = traced.graph.call_function(torch.relu, args=(node,))# We want all places that used the value of `node` to # now use that value after the `relu` call we've added.# We use the `replace_all_uses_with` API to do this.node.replace_all_uses_with(new_node)
对于仅包含替换操作的简单转换,您也可以使用子图重写器。
使用 replace_pattern() 进行子图重写
FX 在直接图操作的基础上提供了更高层次的自动化能力。replace_pattern()
API 本质上是一个用于编辑 Graph
的"查找/替换"工具。它允许你指定一个 pattern
(模式)和 replacement
(替换)函数,然后会追踪这些函数,在图中找到与 pattern
图匹配的操作组实例,并用 replacement
图的副本替换这些实例。这可以极大地自动化繁琐的图操作代码,随着转换逻辑变得复杂,手动操作会变得难以维护。
图操作示例
- 替换单个操作符
- 卷积/批量归一化融合
- replace_pattern:基础用法
- 量化
- 逆变换
代理/回溯机制
另一种操作 Graph
的方式是复用符号追踪中使用的 Proxy
机制。例如,假设我们需要编写一个将 PyTorch 函数分解为更小操作的转换器:将每个 F.relu(x)
调用转换为 (x > 0) * x
。传统做法可能是通过图重写来插入比较和乘法操作,然后清理原始的 F.relu
。但借助 Proxy
对象,我们可以自动将操作记录到 Graph
中来实现这一过程。
具体实现时,只需将需要插入的操作写成常规 PyTorch 代码,并用 Proxy
对象作为参数调用该代码。这些 Proxy
对象会捕获对其执行的操作,并将其追加到 Graph
中。
# Note that this decomposition rule can be read as regular Python
def relu_decomposition(x):return (x 0) * xdecomposition_rules = {}
decomposition_rules[F.relu] = relu_decompositiondef decompose(model: torch.nn.Module, tracer_class : type = fx.Tracer) -torch.nn.Module:"""Decompose `model` into smaller constituent operations.Currently,this only supports decomposing ReLU into itsmathematical definition: (x 0) * x"""graph : fx.Graph = tracer_class().trace(model)new_graph = fx.Graph()env = {}tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)for node in graph.nodes:if node.op == 'call_function' and node.target in decomposition_rules:# By wrapping the arguments with proxies, # we can dispatch to the appropriate# decomposition rule and implicitly add it# to the Graph by symbolically tracing it.proxy_args = [fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args]output_proxy = decomposition_rules[node.target](proxy_args)# Operations on `Proxy` always yield new `Proxy`s, and the # return value of our decomposition rule is no exception.# We need to extract the underlying `Node` from the `Proxy`# to use it in subsequent iterations of this transform.new_node = output_proxy.nodeenv[node.name] = new_nodeelse:# Default case: we don't have a decomposition rule for this # node, so just copy the node over into the new graph.new_node = new_graph.node_copy(node, lambda x: env[x.name])env[node.name] = new_nodereturn fx.GraphModule(model, new_graph)
除了避免显式的图操作外,使用Proxy
还允许您将重写规则指定为原生Python代码。对于需要大量重写规则的转换(如vmap或grad),这通常可以提高规则的可读性和可维护性。
需要注意的是,在调用Proxy
时,我们还传递了一个指向底层变量图的追踪器。这样做是为了防止当图中的操作是n元操作时(例如add是二元运算符),调用Proxy
不会创建多个图追踪器实例,否则可能导致意外的运行时错误。特别是在底层操作不能安全地假设为一元操作时,我们推荐使用这种Proxy
方法。
一个使用Proxy
进行Graph
操作的实际示例可以在这里找到。
解释器模式
在FX中,一个实用的代码组织模式是遍历Graph
中的所有Node
并执行它们。这种模式可用于多种场景,包括:
- 运行时分析流经计算图的值
- 通过
Proxy
重新追踪来实现代码转换
例如,假设我们想运行一个GraphModule
,并在运行时记录节点上torch.Tensor
的形状和数据类型属性。实现代码可能如下:
import torch
import torch.fx
from torch.fx.node import Nodefrom typing import Dictclass ShapeProp:"""Shape propagation. This class takes a `GraphModule`.Then, its `propagate` method executes the `GraphModule`node-by-node with the given arguments. As each operationexecutes, the ShapeProp class stores away the shape and element type for the output values of each operation on the `shape` and `dtype` attributes of the operation's`Node`."""def __init__(self, mod):self.mod = modself.graph = mod.graphself.modules = dict(self.mod.named_modules())def propagate(self, args):args_iter = iter(args)env : Dict[str, Node] = {}def load_arg(a):return torch.fx.graph.map_arg(a, lambda n: env[n.name])def fetch_attr(target : str):target_atoms = target.split('.')attr_itr = self.modfor i, atom in enumerate(target_atoms):if not hasattr(attr_itr, atom):raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")attr_itr = getattr(attr_itr, atom)return attr_itrfor node in self.graph.nodes:if node.op == 'placeholder':result = next(args_iter)elif node.op == 'get_attr':result = fetch_attr(node.target)elif node.op == 'call_function':result = node.target(load_arg(node.args), *load_arg(node.kwargs))elif node.op == 'call_method':self_obj, args = load_arg(node.args)kwargs = load_arg(node.kwargs)result = getattr(self_obj, node.target)(args, *kwargs)elif node.op == 'call_module':result = self.modules[node.target](load_arg(node.args), *load_arg(node.kwargs))# This is the only code specific to shape propagation.# you can delete this `if` branch and this becomes# a generic GraphModule interpreter.if isinstance(result, torch.Tensor):node.shape = result.shapenode.dtype = result.dtypeenv[node.name] = resultreturn load_arg(self.graph.result)
如你所见,为FX实现一个完整的解释器并不复杂,但却非常实用。为了简化这一模式的使用,我们提供了Interpreter
类,它封装了上述逻辑,允许通过方法重写来覆盖解释器执行的某些方面。
除了执行操作外,我们还可以通过向解释器传递Proxy
值来生成新的计算图。
类似地,我们提供了Transformer
类来封装这种模式。Transformer
的行为与Interpreter
类似,但不同于调用run
方法从模块获取具体输出值,你需要调用Transformer.transform()
方法来返回一个新的GraphModule
,该模块会应用你通过重写方法设置的任何转换规则。
解释器模式示例
- 形状传播
- 性能分析器
调试
简介
在编写转换代码的过程中,我们的代码往往不会一开始就完全正确。这时就需要进行调试。关键在于采用逆向思维:首先检查调用生成模块的结果,验证其正确性;接着审查并调试生成的代码;最后追溯导致生成代码的转换过程并进行调试。
如果您不熟悉调试工具,请参阅辅助章节可用调试工具。
变换编写中的常见陷阱
set
迭代顺序的不确定性。在Python中,set
数据类型是无序的。例如,使用set
来存储Node
等对象集合可能导致意外的非确定性行为。比如当迭代一组Node
并将其插入Graph
时,由于set
数据类型是无序的,输出程序中操作的顺序将是非确定性的,且每次程序调用都可能变化。
推荐的替代方案是使用dict
数据类型。自Python 3.7起(以及cPython 3.6起),dict
保持了插入顺序。通过将需要去重的值存储在dict
的键中,可以等效地实现set
的功能。
检查模块的正确性
由于大多数深度学习模块的输出都是浮点型 torch.Tensor
实例,因此检查两个 torch.nn.Module
的结果是否相等并不像简单的相等性检查那样直接。为了说明这一点,我们来看一个示例:
import torch
import torch.fx
import torchvision.models as modelsdef transform(m : torch.nn.Module) -torch.nn.Module:gm = torch.fx.symbolic_trace(m)# Imagine we're doing some transforms here# <...>gm.recompile()return gmresnet18 = models.resnet18()
transformed_resnet18 = transform(resnet18)input_image = torch.randn(5, 3, 224, 224)assert resnet18(input_image) == transformed_resnet18(input_image)
"""
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
"""
在这里,我们尝试使用==
相等运算符来检查两个深度学习模型的值是否相等。然而,这种做法存在两个问题:首先,该运算符返回的是张量而非布尔值;其次,浮点数值的比较应考虑误差范围(或epsilon),以解决浮点运算不可交换性的问题(详见此处)。
我们可以改用torch.allclose()
函数,它会基于相对和绝对容差阈值进行近似比较:
assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))
这是我们工具箱中的第一个工具,用于检查转换后的模块与参考实现相比是否按预期运行。
调试生成的代码
由于 FX 在 GraphModule
上生成 forward()
函数,使用传统的调试技术(如 print
语句或 pdb
)会不太直观。幸运的是,我们有多种方法可以用来调试生成的代码。
使用 pdb
通过调用 pdb
可以进入正在运行的程序进行调试。虽然表示 Graph
的代码不在任何源文件中,但当执行前向传播时,我们仍然可以手动使用 pdb
进入该代码进行调试。
import torch
import torch.fx
import torchvision.models as modelsdef my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -torch.nn.Module:graph = tracer_class().trace(inp)# Transformation logic here# <...># Return new Modulereturn fx.GraphModule(inp, graph)my_module = models.resnet18()
my_module_transformed = my_pass(my_module)input_value = torch.randn(5, 3, 224, 224)# When this line is executed at runtime, we will be dropped into an # interactive `pdb` prompt. We can use the `step` or `s` command to # step into the execution of the next line
import pdb; pdb.set_trace()my_module_transformed(input_value)
打印生成的代码
如果需要多次运行相同的代码,使用pdb
逐步调试到目标代码可能会有些繁琐。这种情况下,一个简单的方法是将生成的forward
传递代码直接复制粘贴到你的代码中,然后在那里进行检查。
# Assume that `traced` is a GraphModule that has undergone some
# number of transforms# Copy this code for later
print(traced)
# Print the code generated from symbolic tracing. This outputs:
"""
def forward(self, y):x = self.xadd_1 = x + y; x = y = Nonereturn add_1
"""# Subclass the original Module
class SubclassM(M):def __init__(self):super().__init__()# Paste the generated `forward` function (the one we printed and # copied above) heredef forward(self, y):x = self.xadd_1 = x + y; x = y = Nonereturn add_1# Create an instance of the original, untraced Module. Then, create an # instance of the Module with the copied `forward` function. We can # now compare the output of both the original and the traced version.
pre_trace = M()
post_trace = SubclassM()
使用 GraphModule
中的 to_folder
函数
GraphModule.to_folder()
是 GraphModule
中的一个方法,它允许你将生成的 FX 代码导出到一个文件夹。虽然像打印生成的代码中那样直接复制前向传播代码通常已经足够,但使用 to_folder
可以更方便地检查模块和参数。
m = symbolic_trace(M())
m.to_folder("foo", "Bar")
from foo import Bar
y = Bar()
运行上述示例后,我们可以查看foo/module.py
中的代码,并根据需要进行修改(例如添加print
语句或使用pdb
)来调试生成的代码。
调试转换过程
既然我们已经确认是转换过程生成了错误代码,现在就该调试转换本身了。首先,我们会查阅文档中的符号追踪限制部分。在确认追踪功能按预期工作后,我们的目标就转变为找出GraphModule
转换过程中出现的问题。编写转换部分可能有快速解决方案,如果没有的话,我们还可以通过多种方式来检查追踪模块:
# Sample Module
class M(torch.nn.Module):def forward(self, x, y):return x + y# Create an instance of `M`
m = M()# Symbolically trace an instance of `M` (returns a GraphModule). In
# this example, we'll only be discussing how to inspect a # GraphModule, so we aren't showing any sample transforms for the # sake of brevity.
traced = symbolic_trace(m)# Print the code produced by tracing the module.
print(traced)
# The generated `forward` function is:
"""
def forward(self, x, y):add = x + y; x = y = Nonereturn add
"""# Print the internal Graph.
print(traced.graph)
# This print-out returns:
"""
graph():%x : [num_users=1] = placeholder[target=x]%y : [num_users=1] = placeholder[target=y]%add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})return add
"""# Print a tabular representation of the internal Graph.
traced.graph.print_tabular()
# This gives us:
"""
opcode name target args kwargs
------------- ------ ----------------------- ------ --------
placeholder x x () {}
placeholder y y () {}
call_function add <built-in function add (x, y) {}
output output output (add,) {}
"""
通过使用上述工具函数,我们可以对比应用转换前后的追踪模块。有时,简单的视觉对比就足以定位错误。如果问题仍不明确,下一步可以尝试使用 pdb
这类调试器。
以上述示例为基础,请看以下代码:
# Sample user-defined function
def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -torch.nn.Module:# Get the Graph from our traced Moduleg = tracer_class().trace(module)"""Transformations on `g` go here"""return fx.GraphModule(module, g)# Transform the Graph
transformed = transform_graph(traced)# Print the new code after our transforms. Check to see if it was
# what we expected
print(transformed)
以上述示例为例,假设调用print(traced)
时发现转换过程中存在错误。我们需要通过调试器定位问题根源。启动pdb
调试会话后,可以在transform_graph(traced)
处设置断点,然后按s
键"步入"该函数调用,实时观察转换过程。
另一个有效方法是修改print_tabular
方法,使其输出图中节点的不同属性(例如查看节点的input_nodes
和users
关系)。
可用的调试器
最常用的Python调试器是pdb。你可以通过在命令行输入python -m pdb FILENAME.py
来以"调试模式"启动程序,其中FILENAME
是你要调试的文件名。之后,你可以使用pdb
的调试器命令逐步执行正在运行的程序。通常的做法是在启动pdb
时设置一个断点(b LINE-NUMBER
),然后调用c
让程序运行到该断点处。这样可以避免你不得不使用s
或n
逐行执行代码才能到达想要检查的部分。或者,你也可以在想中断的代码行前写入import pdb; pdb.set_trace()
。如果添加了pdb.set_trace()
,当你运行程序时它会自动进入调试模式(换句话说,你只需在命令行输入python FILENAME.py
而不用输入python -m pdb FILENAME.py
)。一旦以调试模式运行文件,你就可以使用特定命令逐步执行代码并检查程序的内部状态。网上有很多关于pdb
的优秀教程,包括RealPython的《Python Debugging With Pdb》。
像PyCharm或VSCode这样的IDE通常内置了调试器。在你的IDE中,你可以选择:a)通过调出IDE中的终端窗口(例如在VSCode中选择View → Terminal)使用pdb
,或者b)使用内置的调试器(通常是pdb
的图形化封装)。
符号追踪的局限性
FX 采用符号追踪(又称符号执行)系统,以可转换/可分析的形式捕获程序语义。该系统具有以下特点:
- 追踪性:通过实际执行程序(实际是
torch.nn.Module
或函数)来记录操作 - 符号性:执行过程中流经程序的数据并非真实数据,而是符号(FX术语中称为
Proxy
)
虽然符号追踪适用于大多数神经网络代码,但它仍存在一些局限性。
动态控制流
符号追踪的主要局限在于目前不支持动态控制流。也就是说,当循环或if
语句的条件可能依赖于程序输入值时,就无法处理。
例如,我们来看以下程序:
def func_to_trace(x):if x.sum() 0:return torch.relu(x)else:return torch.neg(x)traced = torch.fx.symbolic_trace(func_to_trace)
"""<...>File "dyn.py", line 6, in func_to_traceif x.sum() 0:File "pytorch/torch/fx/proxy.py", line 155, in __bool__return self.tracer.to_bool(self)File "pytorch/torch/fx/proxy.py", line 85, in to_boolraise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""
if
语句的条件依赖于x.sum()
的值,而该值又依赖于函数输入x
。由于x
可能发生变化(例如向追踪函数传入新的输入张量时),这就形成了动态控制流。回溯信息会沿着代码向上追溯,展示这种情况发生的位置。
静态控制流
另一方面,系统支持所谓的静态控制流。静态控制流指的是那些在多次调用中值不会改变的循环或if
语句。通常在PyTorch程序中,这种控制流出现在根据超参数决定模型架构的代码中。举个具体例子:
import torch
import torch.fxclass MyModule(torch.nn.Module):def __init__(self, do_activation : bool = False):super().__init__()self.do_activation = do_activationself.linear = torch.nn.Linear(512, 512)def forward(self, x):x = self.linear(x)# This if-statement is so-called static control flow.# Its condition does not depend on any input valuesif self.do_activation:x = torch.relu(x)return xwithout_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)traced_without_activation = torch.fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
"""
def forward(self, x):linear_1 = self.linear(x); x = Nonereturn linear_1
"""traced_with_activation = torch.fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
"""
import torch
def forward(self, x):linear_1 = self.linear(x); x = Nonerelu_1 = torch.relu(linear_1); linear_1 = Nonereturn relu_1
"""
if self.do_activation
这个条件语句不依赖于任何函数输入,因此它是静态的。do_activation
可以被视为一个超参数,当 MyModule
的不同实例使用不同参数值时,生成的代码轨迹也会不同。这是一种有效模式,符号追踪功能支持这种模式。
许多动态控制流的实例在语义上其实是静态控制流。通过消除对输入值的数据依赖,这些实例可以支持符号追踪。具体方法包括:
- 将值移至
Module
属性中 - 在符号追踪期间将具体值绑定到参数上
def f(x, flag):if flag: return xelse: return x*2fx.symbolic_trace(f) # Fails!fx.symbolic_trace(f, concrete_args={'flag': True})
在真正动态控制流的情况下,包含此类代码的程序部分可以被追踪为对方法的调用(参见使用Tracer类自定义追踪)或函数调用(参见wrap()
),而不是直接追踪这些代码本身。
非torch
函数
FX采用__torch_function__
作为拦截调用的机制(更多技术细节请参阅技术概览)。某些函数(如Python内置函数或math
模块中的函数)不受__torch_function__
覆盖,但我们仍希望在符号追踪中捕获它们。例如:
import torch
import torch.fx
from math import sqrtdef normalize(x):"""Normalize `x` by the size of the batch dimension"""return x / sqrt(len(x))# It's valid Python code
normalize(torch.rand(3, 4))traced = torch.fx.symbolic_trace(normalize)
"""<...>File "sqrt.py", line 9, in normalizereturn x / sqrt(len(x))File "pytorch/torch/fx/proxy.py", line 161, in __len__raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""
错误提示表明内置函数 len
不被支持。
我们可以通过 wrap()
API 将此类函数记录为跟踪中的直接调用:
torch.fx.wrap('len')
torch.fx.wrap('sqrt')traced = torch.fx.symbolic_trace(normalize)print(traced.code)
"""
import math
def forward(self, x):len_1 = len(x)sqrt_1 = math.sqrt(len_1); len_1 = Nonetruediv = x / sqrt_1; x = sqrt_1 = Nonereturn truediv
"""
使用 Tracer
类自定义追踪功能
Tracer
类是 symbolic_trace
功能的基础实现类。通过继承 Tracer 类,可以自定义追踪行为,例如:
class MyCustomTracer(torch.fx.Tracer):# Inside here you can override various methods# to customize tracing. See the `Tracer` API# referencepass# Let's use this custom tracer to trace through this module
class MyModule(torch.nn.Module):def forward(self, x):return torch.relu(x) + torch.ones(3, 4)mod = MyModule()traced_graph = MyCustomTracer().trace(mod)
# trace() returns a Graph. Let's wrap it up in a # GraphModule to make it runnable
traced = torch.fx.GraphModule(mod, traced_graph)
叶子模块
叶子模块是指在符号追踪过程中作为调用出现,而不会被继续追踪的模块。默认的叶子模块集合由标准torch.nn
模块实例组成。例如:
class MySpecialSubmodule(torch.nn.Module):def forward(self, x):return torch.neg(x)class MyModule(torch.nn.Module):def __init__(self):super().__init__()self.linear = torch.nn.Linear(3, 4)self.submod = MySpecialSubmodule()def forward(self, x):return self.submod(self.linear(x))traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
# `linear` is preserved as a call, yet `submod` is traced though.
# This is because the default set of "Leaf Modules" includes all
# standard `torch.nn` modules.
"""
import torch
def forward(self, x):linear_1 = self.linear(x); x = Noneneg_1 = torch.neg(linear_1); linear_1 = Nonereturn neg_1
"""
可以通过重写 Tracer.is_leaf_module()
来自定义叶子模块集合。
杂项说明
- 当前无法追踪张量构造函数(如
torch.zeros
、torch.ones
、torch.rand
、torch.randn
、torch.sparse_coo_tensor
):- 确定性构造函数(
zeros
、ones
)仍可使用,其生成的值会作为常量嵌入追踪记录。仅当这些构造函数的参数涉及动态输入大小时才会出现问题,此时可改用ones_like
或zeros_like
作为替代方案。 - 非确定性构造函数(
rand
、randn
)会将单个随机值嵌入追踪记录,这通常不符合预期行为。变通方法是将torch.randn
包装在torch.fx.wrap
函数中并调用该包装函数。
- 确定性构造函数(
(注:保留所有代码块及技术术语原貌,被动语态转为主动表述,长句拆分后保持技术严谨性)
@torch.fx.wrap
def torch_randn(x, shape):return torch.randn(shape)def f(x):return x + torch_randn(x, 5)
fx.symbolic_trace(f)
此行为可能在未来的版本中修复。
- 类型注解
-
支持 Python 3 风格的类型注解(例如
func(x : torch.Tensor, y : int) -torch.Tensor
),
并且会通过符号追踪保留这些注解。 -
目前不支持 Python 2 风格的注释类型注解
# type: (torch.Tensor, int) -torch.Tensor
。 -
目前不支持函数内部局部变量的类型注解。
- 关于
training
标志和子模块的注意事项
- 当使用像
torch.nn.functional.dropout
这样的函数时,通常会传入self.training
作为训练参数。在 FX 追踪过程中,这个值很可能会被固定为一个常量。
import torch
import torch.fxclass DropoutRepro(torch.nn.Module):def forward(self, x):return torch.nn.functional.dropout(x, training=self.training)traced = torch.fx.symbolic_trace(DropoutRepro())
print(traced.code)
"""
def forward(self, x):dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False); x = Nonereturn dropout
"""traced.eval()x = torch.randn(5, 3)
torch.testing.assert_close(traced(x), x)
"""
AssertionError: Tensor-likes are not close!Mismatched elements: 15 / 15 (100.0%)
Greatest absolute difference: 1.6207983493804932 at index (0, 2) (up to 1e-05 allowed)
Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed)
"""
然而,当使用标准的 nn.Dropout()
子模块时,训练标志会被封装起来,并且由于保留了 nn.Module
对象模型,可以对其进行修改。
class DropoutRepro2(torch.nn.Module):def __init__(self):super().__init__()self.drop = torch.nn.Dropout()def forward(self, x):return self.drop(x)traced = torch.fx.symbolic_trace(DropoutRepro2())
print(traced.code)
"""
def forward(self, x):drop = self.drop(x); x = Nonereturn drop
"""traced.eval()x = torch.randn(5, 3)
torch.testing.assert_close(traced(x), x)
由于这一差异,建议将与动态training
标志交互的模块标记为叶模块。
API 参考
torch.fx.symbolic_trace(root, concrete_args=None)
符号追踪 API
给定一个 nn.Module
或函数实例 root
,该 API 会返回一个 GraphModule
,这是通过记录追踪 root
时观察到的操作构建而成的。
concrete_args
参数允许你对函数进行部分特化,无论是为了移除控制流还是数据结构。
例如:
def f(a, b):if b == True:return a else:return a * 2
由于控制流的存在,FX通常无法追踪此过程。不过,我们可以使用concrete_args
来针对变量b的值进行特化处理,从而实现追踪:
f = fx.symbolic_trace(f, concrete_args={"b": False})
assert f(3, False) == 6
请注意,虽然您仍可以传入不同的b值,但这些值将被忽略。
我们还可以使用concrete_args
来消除函数中对数据结构的处理。这将利用pytrees来展平您的输入。为了避免过度特化,对于不应特化的值,请传入fx.PH
。例如:
def f(x):out = 0for v in x.values():out += vreturn outf = fx.symbolic_trace(f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}})
assert f({"a": 1, "b": 2, "c": 4}) == 7
参数
root (Union[torch.nn.Module, Callable])
- 待追踪并转换为图表示形式的模块或函数concrete_args (Optional[Dict[str, any]])
- 需要部分特化的输入参数
返回从root
记录的操作所创建的模块。
返回类型:GraphModule
注意:此API保证向后兼容性。
torch.fx.wrap(fn_or_name)
该函数可在模块级作用域调用,将fn_or_name
注册为"叶子函数"。
"叶子函数"在FX跟踪中会保留为CallFunction节点,而不会被进一步跟踪。
# foo/bar/baz.py
def my_custom_function(x, y):return x * x + y * ytorch.fx.wrap("my_custom_function")def fn_to_be_traced(x, y):# When symbolic tracing, the below call to my_custom_function will be inserted into# the graph rather than tracing it.return my_custom_function(x, y)
该函数也可以等效地用作装饰器:
# foo/bar/baz.py
@torch.fx.wrap
def my_custom_function(x, y):return x * x + y * y
包装函数可以被视为"叶子函数",类似于"叶子模块"的概念,也就是说,这些函数在FX跟踪中会保留为调用点,而不会被进一步追踪。
参数
fn_or_name (Union[str, Callable])
- 当被调用时,要插入到图中的函数或全局函数名称
注意:此API保证向后兼容性。
class torch.fx.GraphModule(*args, **kwargs)
GraphModule 是由 fx.Graph 生成的 nn.Module。GraphModule 具有一个 graph
属性,以及从该 graph
生成的 code
和 forward
属性。
警告:当重新分配 graph
时,code
和 forward
将自动重新生成。但如果你编辑了 graph
的内容而没有重新分配 graph
属性本身,则必须调用 recompile()
来更新生成的代码。
注意:此 API 保证向后兼容性。
__init__(root, graph, class_name='GraphModule')
构建一个 GraphModule。
参数
root (Union[torch.nn.Module , Dict[str, Any])
–root
可以是 nn.Module 实例,也可以是将字符串映射到任意属性类型的字典。
当 root
是 Module 时,Graph 的 Nodes 中 target
字段对基于 Module 的对象(通过限定名称引用)的任何引用,都会从 root
的 Module 层次结构中的相应位置复制到 GraphModule 的模块层次结构中。
当 root
是字典时,Node 的 target
中找到的限定名称将直接在字典的键中查找。字典映射的对象将被复制到 GraphModule 模块层次结构中的适当位置。
graph (Graph)
–graph
包含此 GraphModule 用于代码生成的节点class_name (str)
–name
表示此 GraphModule 的名称,用于调试目的。如果未设置,所有错误消息将报告为源自GraphModule
。将其设置为root
的原始名称或在转换上下文中合理的名称可能会有所帮助。
注意:此 API 保证向后兼容性。
add_submodule(target, m)
将给定的子模块添加到self
中。
如果target
是子路径且对应位置尚未存在模块,此方法会安装空的模块。
参数
target (str)
- 新子模块的完整限定字符串名称
(参见nn.Module.get_submodule
中的示例了解如何指定完整限定字符串)m (Module)
- 子模块本身;即我们想要安装到当前模块中的实际对象
返回
子模块是否能够被插入。要使该方法返回True,target
表示的链中每个对象必须满足以下条件之一:
a) 尚不存在,或
b) 引用的是nn.Module
(而非参数或其他属性)
返回类型:bool
注意:此API保证向后兼容性。
property code: str
返回从该 GraphModule
底层 Graph
生成的 Python 代码。
delete_all_unused_submodules()
***
Deletes all unused submodules from `self`.A Module is considered “used” if any one of the following is true:
1、It has children that are used
2、Its forward is called directly via a `call_module` node
3、It has a non-Module attribute that is used from a `get_attr` nodeThis method can be called to clean up an `nn.Module` without
manually calling `delete_submodule` on each unused submodule.
***
Note: Backwards-compatibility for this API is guaranteed.delete_submodule(target)
从self
中删除指定的子模块。
如果target
不是有效的目标,则不会删除该模块。
参数
target (str)
- 新子模块的完全限定字符串名称
(有关如何指定完全限定字符串的示例,请参阅nn.Module.get_submodule
)
返回值
表示目标字符串是否引用了我们要删除的子模块。返回值为False
意味着target
不是有效的子模块引用。
返回类型 : bool
注意:此API保证向后兼容性。
property graph: [Graph](https://pytorch.org/docs/stable/data.html#torch.fx.Graph "torch.fx.graph.Graph")
返回该 GraphModule
底层对应的 Graph
print_readable(print_output=True, include_stride=False, include_device=False, colored=False)
返回为当前 GraphModule 及其子 GraphModule 生成的 Python 代码
警告:此 API 为实验性质,且不保证向后兼容性。
recompile()
根据其 graph
属性重新编译该 GraphModule。在编辑包含的 graph
后应调用此方法,否则该 GraphModule
生成的代码将过期。
注意:此 API 保证向后兼容性。
返回类型:PythonCode
to_folder(folder, module_name='FxModule')
将模块以 module_name
名称转储到 folder
目录下,以便可以通过 from <folder> import <module_name>
方式导入。
参数:
folder (Union [str, os.PathLike])
: 用于输出代码的目标文件夹路径
module_name (str): 在输出代码时使用的顶层模块名称
警告:此 API 为实验性质,不保证向后兼容性。
class torch.fx.Graph(owning_module=None, tracer_cls=None, tracer_extras=None)
Graph
是 FX 中间表示中使用的主要数据结构。
它由一系列 Node
组成,每个节点代表调用点(或其他语法结构)。这些 Node
的集合共同构成了一个有效的 Python 函数。
例如,以下代码
import torch
import torch.fxclass MyModule(torch.nn.Module):def __init__(self):super().__init__()self.param = torch.nn.Parameter(torch.rand(3, 4))self.linear = torch.nn.Linear(4, 5)def forward(self, x):return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)m = MyModule()
gm = torch.fx.symbolic_trace(m)
将生成以下图表:
print(gm.graph)
graph(x):%linear_weight : [num_users=1] = self.linear.weight%add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})%linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})%relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})%sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})%topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})return topk_1
关于Graph
中操作的具体语义,请参阅Node
文档。
注意:本API保证向后兼容性。
__init__(owning_module=None, tracer_cls=None, tracer_extras=None)
构建一个空图。
注意:此 API 保证向后兼容性。
call_function(the_function, args=None, kwargs=None, type_expr=None)
在Graph
中插入一个call_function
类型的Node
。call_function
节点表示对Python可调用对象的调用,由the_function
指定。
参数
the_function (Callable[...*, Any])
– 要调用的函数。可以是任何PyTorch运算符、Python函数,或属于builtins
或operator
命名空间的成员。args (Optional[Tuple[Argument*, ...]])
– 传递给被调用函数的位置参数。kwargs (Optional[Dict[str, Argument]])
– 传递给被调用函数的关键字参数。type_expr (Optional[Any])
– 可选的类型注解,表示该节点输出值的Python类型。
返回
新创建并插入的call_function
节点。
返回类型
Node
注意:此方法的插入点和类型表达式规则与Graph.create_node()
相同。
注意:此API保证向后兼容性。
call_method(method_name, args=None, kwargs=None, type_expr=None)
向Graph
中插入一个call_method
节点。call_method
节点表示对args
第0个元素调用指定方法。
参数
method_name (str)
- 要应用于self参数的方法名称。例如,如果args[0]是一个表示Tensor
的Node
,那么要对该Tensor
调用relu()
方法时,需将relu
作为method_name
传入。args (Optional[Tuple[Argument*, ...]])
- 要传递给被调用方法的位置参数。注意这应该包含一个self参数。kwargs (Optional[Dict[str, Argument]])
- 要传递给被调用方法的关键字参数type_expr (Optional[Any])
- 可选的类型注解,表示该节点输出结果的Python类型。
返回
新创建并插入的call_method
节点。
返回类型
Node
注意:本方法的插入点和类型表达式规则与Graph.create_node()
相同。
注意:此API保证向后兼容性。
call_module(module_name, args=None, kwargs=None, type_expr=None)
向Graph
中插入一个call_module
类型的Node
节点。call_module
节点表示对Module
层级结构中某个Module
的forward()函数的调用。
参数
module_name (str)
- 要调用的Module
在层级结构中的限定名称。例如,若被追踪的Module
有一个名为foo
的子模块,而该子模块又包含名为bar
的子模块,则应以foo.bar
作为module_name
来调用该模块。args (Optional[Tuple[Argument*, ...]])
- 传递给被调用方法的位置参数。注意:此处不应包含self
参数。kwargs (Optional[Dict[str, Argument]])
- 传递给被调用方法的关键字参数type_expr (Optional[Any])
- 可选类型注解,表示该节点输出值的Python类型。
返回
新创建并插入的call_module
节点。
返回类型:Node
注意:本方法的插入点与类型表达式规则与Graph.create_node()
相同。
注意:本API保证向后兼容性。
create_node(op, target, args=None, kwargs=None, name=None, type_expr=None)
创建一个 Node
并将其添加到当前插入点的 Graph
中。
注意:当前插入点可以通过 Graph.inserting_before()
和 Graph.inserting_after()
进行设置。
参数
op (str)
- 该节点的操作码。可选值包括 ‘call_function’、‘call_method’、‘get_attr’、‘call_module’、‘placeholder’ 或 ‘output’。这些操作码的语义在Graph
的文档字符串中有详细说明。args (Optional[Tuple[Argument*, ...]])
- 该节点的参数元组。kwargs (Optional[Dict[str, Argument]])
- 该节点的关键字参数。name (Optional[str])
- 为Node
指定的可选字符串名称。这将影响生成的 Python 代码中赋值给该节点的变量名。type_expr (Optional[Any])
- 可选类型注解,表示该节点输出值的 Python 类型。
返回
新创建并插入的节点。
返回类型:Node
注意:此 API 保证向后兼容。
eliminate_dead_code(is_impure_node=None)
根据图中各节点的用户数量及是否具有副作用,移除所有死代码。调用前必须确保图已完成拓扑排序。
参数
is_impure_node (Optional[Callable[[Node],* [bool]]])
—— 用于判断节点是否为非纯函数的回调函数。若未提供该参数,则默认使用Node.is_impure
方法。
返回值:返回布尔值,表示该过程是否导致图结构发生变更。
返回类型:bool
示例
在消除死代码前,下方表达式 a = x + 1
中的变量 a
无用户引用,因此可从图中安全移除而不影响结果。
def forward(self, x):a = x + 1return x + self.attr_1
消除死代码后,a = x + 1
已被移除,前向传播部分的其他代码保留不变。
def forward(self, x):return x + self.attr_1
警告:死代码消除机制虽然采用了一些启发式方法来避免删除具有副作用的节点(参见 Node.is_impure
),但总体覆盖率非常不理想。因此,除非你明确知道当前 FX 计算图完全由无副作用的操作构成,或者自行提供了检测副作用节点的自定义函数,否则不应假设调用此方法是安全可靠的。
注意:本 API 保证向后兼容性。
erase_node(to_erase)
从Graph
中删除一个Node
。如果该节点在Graph
中仍被使用,将抛出异常。
参数
to_erase (Node)
– 要从Graph
中删除的Node
。
注意:此API保证向后兼容性。
find_nodes(*, op, target=None, sort=True)
支持快速查询节点
参数
op (str)
– 操作名称target (Optional[Target])
– 节点目标。对于call_function操作,target为必填项;其他操作中target为可选参数。sort ([bool])
– 是否按节点在图中出现的顺序返回结果。
返回值:返回符合指定op和target条件的节点迭代器。
警告:此API为实验性质,且不保证向后兼容。
get_attr(qualified_name, type_expr=None)
向图中插入一个 get_attr
节点。get_attr
类型的 Node
表示从 Module
层次结构中获取某个属性。
参数
qualified_name (str)
- 要获取属性的全限定名称。例如,若被追踪的 Module 包含名为foo
的子模块,该子模块又包含名为bar
的子模块,而bar
拥有名为baz
的属性,则应将全限定名称foo.bar.baz
作为qualified_name
传入。type_expr (Optional[Any])
- 可选的类型注解,用于表示该节点输出值的 Python 类型。
返回
新创建并插入的 get_attr
节点。
返回类型:Node
注意:本方法的插入点与类型表达式规则与 Graph.create_node
方法保持一致。
注意:此 API 保证向后兼容性。
graph_copy(g, val_map, return_output_node=False)
将给定图中的所有节点复制到 self
中。
参数
g (Graph)
– 作为节点复制来源的原始图。val_map (Dict[Node,* Node])
– 用于存储节点映射关系的字典,键为g
中的节点,值为self
中的对应节点。注意:val_map
可预先包含值以实现特定值的复制覆盖。
返回值:如果 g
包含输出节点,则返回 self
中与 g
输出值等效的值;否则返回 None
。
返回类型:Optional[Union [tuple [Argument, …], Sequence [Argument], Mapping [str , Argument], slice , range , Node, str , int , float, bool , complex , [dtype](tensor_attributes.html#torch.dtype "torch.dtype"), Tensor , device , memory_format , layout , OpOverload, [SymInt](torch.html#torch.SymInt "torch.SymInt"), SymBool , SymFloat ]]
注意:本API保证向后兼容性。
inserting_after(n=None)
设置 create_node
及相关方法在图中插入节点的位置。当在 with
语句中使用时,这会临时设置插入点,并在 with
语句退出时恢复原位置。
with g.inserting_after(n):... # inserting after node n
... # insert point restored to what it was previously
g.inserting_after(n) # set the insert point permanently
参数:
n (可选[Node]): 要在其之前插入的节点。如果为None,则会在整个图的起始位置之后插入。
返回:
一个资源管理器,它会在__exit__
时恢复插入点。
注意:此API保证向后兼容性。
inserting_before(n=None)
设置 create_node
及相关方法在图中插入节点的基准位置。当在 with
语句中使用时,这将临时设置插入点,并在 with
语句退出时恢复原位置。
with g.inserting_before(n):... # inserting before node n
... # insert point restored to what it was previously
g.inserting_before(n) # set the insert point permanently
参数:
n (Optional[Node]): 要插入位置的前一个节点。如果为None,则会在整个图的起始位置前插入。
返回:
一个资源管理器,该管理器会在__exit__
时恢复插入点。
注意:此API保证向后兼容性。
lint()
对该图执行多项检查以确保其结构正确。具体包括:
- 检查节点是否具有正确的所有权(由本图所有)
- 检查节点是否按拓扑顺序排列
- 若该图拥有所属的GraphModule,则检查目标是否存在该GraphModule中
注:本API保证向后兼容性。
node_copy(node, arg_transform=<function Graph.<lambda>>)
将节点从一个图复制到另一个图中。arg_transform
需要将节点所在图的参数转换为目标图(self)的参数。示例:
# Copying all the nodes in `g` into `new_graph`
g: torch.fx.Graph = ...
new_graph = torch.fx.graph()
value_remap = {}for node in g.nodes:value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])
参数
node (Node)
– 要复制到self
中的节点。arg_transform (Callable[[Node], Argument])
– 一个函数,用于将节点args
和kwargs
中的Node
参数转换为self
中的等效参数。最简单的情况下,该函数应从原始图中节点到self
的映射表中检索值。
返回类型:Node
注意:此 API 保证向后兼容性。
property nodes: _node_list
获取构成该图的所有节点列表。
请注意,这个Node
列表是以双向链表的形式表示的。在迭代过程中进行修改(例如删除节点、添加节点)是安全的。
返回值:一个双向链表结构的节点列表。注意可以对该列表调用reversed
方法来切换迭代顺序。
on_generate_code(make_transformer)
在生成 Python 代码时注册转换器函数
参数:
make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]):返回待注册代码转换器的函数。
该函数由 on_generate_code 调用以获取代码转换器。
此函数的输入参数为当前已注册的代码转换器(若未注册则为 None),以便在不需要覆盖时使用。该机制可用于串联多个代码转换器。
返回值:一个上下文管理器,当在 with 语句中使用时,会自动恢复先前注册的代码转换器。
示例:
gm: fx.GraphModule = ...# This is a code transformer we want to register. This code
# transformer prepends a pdb import and trace statement at the very
# beginning of the generated torch.fx code to allow for manual
# debugging with the PDB library.
def insert_pdb(body):return ["import pdb; pdb.set_trace()\n", body]# Registers `insert_pdb`, and overwrites the current registered
# code transformer (given by `_` to the lambda):
gm.graph.on_generate_code(lambda _: insert_pdb)# Or alternatively, registers a code transformer which first
# runs `body` through existing registered transformer, then
# through `insert_pdb`:
gm.graph.on_generate_code(lambda current_trans: (lambda body: insert_pdb(current_trans(body) if current_trans else body))
)gm.recompile()
gm(inputs) # drops into pdb
该函数也可作为上下文管理器使用,其优势在于能自动恢复之前注册的代码转换器。
# ... continue from previous examplewith gm.graph.on_generate_code(lambda _: insert_pdb):# do more stuff with `gm`...gm.recompile()gm(inputs) # drops into pdb# now previous code transformer is restored (but `gm`'s code with pdb
# remains - that means you can run `gm` with pdb here too, until you # run next `recompile()`).
警告:此 API 为实验性质,且不向后兼容。
output(result, type_expr=None)
将 output
Node
插入到 Graph
中。output
节点代表 Python 代码中的 return
语句。result
是应当返回的值。
参数
result (Argument)
– 要返回的值。type_expr (Optional[Any])
– 可选的类型注解,表示此节点输出将具有的 Python 类型。
注意:此方法的插入点和类型表达式规则与 Graph.create_node
相同。
注意:此 API 保证向后兼容性。
output_node()
警告:此 API 为实验性质,且不向后兼容。
返回值类型:Node
placeholder(name, type_expr=None, default_value)
在图中插入一个placeholder
节点。placeholder
表示函数的输入参数。
参数
name (str)
- 输入值的名称。这对应于该Graph
所表示函数的位置参数名称。type_expr (Optional[Any])
- 可选的类型注解,表示该节点输出值的Python类型。在某些情况下(例如当函数后续用于TorchScript编译时),这是生成正确代码所必需的。default_value (Any)
- 该函数参数的默认值。注意:为了允许None作为默认值,当参数没有默认值时,应传递inspect.Signature.empty来指定。
返回类型:Node
注意:此方法的插入点和类型表达式规则与Graph.create_node
相同。
注意:此API保证向后兼容性。
print_tabular()
以表格形式打印图的中间表示。注意:此API需要安装tabulate
模块。
注:该API保证向后兼容性。
process_inputs(*args)
处理参数以便它们可以传递到 FX 计算图中。
警告:此 API 为实验性质,且不向后兼容。
process_outputs(out)
警告:此 API 为实验性质,且不向后兼容。
python_code(root_module, *, verbose=False, include_stride=False, include_device=False, colored=False)
将这段Graph
转换为有效的Python代码。
参数
root_module (str)
– 用于查找限定名称目标的根模块名称。通常为’self’。
返回值:src: 表示该对象的Python源代码
globals: 包含src中全局名称及其引用对象的字典
返回类型:一个包含两个字段的PythonCode对象
注意:此API保证向后兼容性。
set_codegen(codegen)
警告:此 API 为实验性功能,且不向后兼容。
class torch.fx.Node(graph, name, op, target, args, kwargs, return_type=None)
Node
是表示 Graph
中单个操作的数据结构。在大多数情况下,Node 表示对各种实体的调用点,例如运算符、方法和模块(某些例外包括指定函数输入和输出的节点)。每个 Node
都有一个由其 op
属性指定的函数。不同 op
值的 Node
语义如下:
placeholder
表示函数输入。name
属性指定该值的名称。target
同样是参数的名称。args
包含:1) 空值,或 2) 表示函数输入默认参数的单个参数。kwargs
无关紧要。占位符对应于图形输出中的函数参数(例如x
)。get_attr
从模块层次结构中检索参数。name
同样是获取结果后赋值的名称。target
是参数在模块层次结构中的完全限定名称。args
和kwargs
无关紧要。call_function
将自由函数应用于某些值。name
同样是赋值目标的名称。target
是要应用的函数。args
和kwargs
表示函数的参数,遵循 Python 调用约定。call_module
将模块层次结构中的forward()
方法应用于给定参数。name
同前。target
是要调用的模块在模块层次结构中的完全限定名称。args
和kwargs
表示调用模块时的参数(不包括 self 参数*)。call_method
调用值的方法。name
类似。target
是要应用于self
参数的方法名称字符串。args
和kwargs
表示调用模块时的参数(包括 self 参数*)。output
在其args[0]
属性中包含跟踪函数的输出。这对应于图形输出中的 “return” 语句。
注意:此 API 保证向后兼容。
property all_input_nodes: list ['Node']
Return all Nodes that are inputs to this Node. This is equivalent to iterating over `args` and `kwargs` and only collecting the values that are Nodes.Returns
List of `Nodes` that appear in the `args` and `kwargs` of this `Node`, in that order.append(x)
在图的节点列表中,将 x
插入到当前节点之后。
等价于调用 self.next.prepend(x)
参数
x (Node)
– 要插入到当前节点后的节点。必须属于同一个图。
注意:此 API 保证向后兼容。
property args: tuple [Union [tuple ['Argument',
...], collections.abc.Sequence ['Argument'], collections.abc.Mapping[str , 'Argument'], slice , range , torch.fx.node.Node, str , int , float, bool , complex , torch.dtype , torch.Tensor, torch.device , torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType],
...]
该Node
的参数元组。参数的具体含义取决于节点的操作码(opcode)。更多信息请参阅Node
文档字符串。
允许对此属性进行赋值操作。所有关于使用情况和用户的记录都会在赋值时自动更新。
format_node(placeholder_names=None, maybe_return_typename=None)
返回一个描述性的字符串表示形式self
。
该方法可不带参数使用,作为调试工具。
此函数也用于Graph
的__str__
方法内部。placeholder_names
和maybe_return_typename
中的字符串共同构成了该Graph所属GraphModule中自动生成的forward
函数的签名。placeholder_names
和maybe_return_typename
不应在其他情况下使用。
参数
placeholder_names (Optional[list[str]])
- 一个列表,用于存储表示生成的forward
函数中占位符的格式化字符串。仅供内部使用。maybe_return_typename (Optional[list[str]])
- 一个单元素列表,用于存储表示生成的forward
函数输出的格式化字符串。仅供内部使用。
返回
如果1)我们在Graph
的__str__
方法中将format_node
用作内部辅助工具,且2)self
是一个占位符Node,则返回None
。否则,返回当前Node的描述性字符串表示形式。
返回类型:str
注意:此API保证向后兼容。
insert_arg(idx, arg)
在参数列表的指定索引位置插入一个位置参数。
参数
idx ( int )
– 要插入到self.args
中元素之前的索引位置。arg (Argument)
– 要插入到args
中的新参数值
注意:本API保证向后兼容性。
is_impure()
返回该操作是否为不纯操作,即判断其操作是否为占位符或输出,或者是否为不纯的call_function
或call_module
。
返回值:指示该操作是否不纯。
返回类型:bool
警告:此API为实验性质,且不向后兼容。
property kwargs: dict[str , Union [tuple ['Argument',
...], collections.abc.Sequence['Argument'], collections.abc.Mapping, [str , 'Argument'], slice , range , torch.fx.node.Node, str , int , float, bool , complex , torch.dtype , torch.Tensor, torch.device , torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType]]
该Node
的关键字参数字典。参数的解析取决于节点的操作码。更多信息请参阅Node
文档字符串。
允许对此属性进行赋值。所有关于使用情况和用户的统计都会在赋值时自动更新。
property next: Node
返回链表中下一个Node
节点。
返回值:链表中下一个Node
节点。
normalized_arguments(root, arg_types=None, kwarg_types=None, normalize_to_only_use_kwargs=False)
返回经过标准化的Python目标参数。这意味着当normalize_to_only_use_kwargs
为真时,args/kwargs将与模块/函数的签名匹配,并按位置顺序仅返回kwargs。
同时会填充默认值。不支持仅限位置参数或可变参数。
支持模块调用。
可能需要arg_types
和kwarg_types
来消除重载歧义。
参数
root (torch.nn.Module)
– 用于解析模块目标的基模块arg_types (Optional[Tuple[Any]])
– 参数的元组类型kwarg_types (Optional[Dict[str, Any]])
– 关键字参数的字典类型normalize_to_only_use_kwargs ([bool])
– 是否标准化为仅使用kwargs
返回
返回命名元组ArgsKwargsPair
,若失败则返回None
返回类型
Optional[ArgsKwargsPair]
警告:该API为实验性质,不保证向后兼容。
prepend(x)
在图的节点列表中,在此节点前插入x。示例:
Before: p -selfbx -x -ax
After: p -x -selfbx -ax
参数
x (Node)
– 要放置在该节点之前的节点。必须是同一图的成员。
注意:此 API 保证向后兼容。
property prev: Node
返回链表中当前节点的前一个Node
。
返回值:链表中当前节点的前一个Node
。
replace_all_uses_with(replace_with, delete_user_cb=<function Node.<lambda>>, *, propagate_meta=False)
将图中所有使用 self
的地方替换为节点 replace_with
。
参数
replace_with (Node)
– 用于替换所有self
的节点。delete_user_cb (Callable)
– 回调函数,用于判断是否应移除某个使用原self
节点的用户节点。propagate_meta ([bool])
– 是否将原节点.meta
字段的所有属性复制到替换节点上。出于安全考虑,仅当替换节点本身没有.meta
字段时才允许此操作。
返回值
返回受此变更影响的节点列表。
返回类型:list [Node]
注意:此 API 保证向后兼容。
replace_input_with(old_input, new_input)
遍历 self
的输入节点,将所有 old_input
实例替换为 new_input
。
参数
old_input (Node)
– 需要被替换的旧输入节点。new_input (Node)
– 用于替换old_input
的新输入节点。
注意:此 API 保证向后兼容性。
property stack_trace: Optional[str ]
返回在追踪过程中记录的 Python 堆栈跟踪信息(如果有)。
当使用 fx.Tracer
进行追踪时,该属性通常由 Tracer.create_proxy
填充。若需在追踪过程中记录堆栈跟踪以用于调试,请在 Tracer 实例上设置 record_stack_traces = True
。
当使用 dynamo 进行追踪时,该属性默认会由 OutputGraph.create_proxy
填充。
stack_trace
的字符串末尾将包含最内层的调用帧。
update_arg(idx, arg)
更新现有位置参数以包含新值
调用后,self.args[idx] == arg
将成立。
参数
idx ( int )
- 要更新元素在self.args
中的索引位置arg (Argument)
- 要写入args
的新参数值
注意:此 API 保证向后兼容性。
update_kwarg(key, arg)
更新现有关键字参数以包含新值
arg
。调用后,self.kwargs[key] == arg
。
参数
key (str)
- 要更新的元素在self.kwargs
中的键名arg (Argument)
- 要写入kwargs
的新参数值
注意:此API保证向后兼容性。
class torch.fx.Tracer(autowrap_modules=(math,), autowrap_functions=())
Tracer
是实现 torch.fx.symbolic_trace
符号追踪功能的类。调用 symbolic_trace(m)
等价于执行 Tracer().trace(m)
。
可以通过继承 Tracer 类来覆盖追踪过程中的各种行为。具体可覆盖的行为详见该类方法的文档字符串。
注意:此 API 保证向后兼容。
call_module(m, forward, args, kwargs)
该方法定义了当Tracer
遇到对nn.Module
实例调用时的行为。
默认行为是通过is_leaf_module
检查被调用的模块是否为叶子模块。如果是,则在Graph
中生成指向m
的call_module
节点;否则正常调用该Module
,并跟踪其forward
函数中的操作。
可通过重写此方法实现自定义行为,例如:
- 创建嵌套的追踪GraphModules
- 实现跨
Module
边界追踪时的特殊处理
参数说明:
m (Module)
- 当前被调用的模块实例forward (Callable)
- 待调用模块的forward()方法args (Tuple)
- 模块调用点的参数元组kwargs (Dict)
- 模块调用点的关键字参数字典
返回值:
- 若生成
call_module
节点,则返回Proxy
代理值 - 否则返回模块调用的原始结果
返回类型:任意类型
注意:本API保证向后兼容性。
create_arg(a)
一种方法,用于指定在准备值作为Graph
中节点的参数时追踪的行为。
默认行为包括:
1、遍历集合类型(如元组、列表、字典)并递归地对元素调用create_args
。
2、给定一个Proxy对象,返回底层IR Node
的引用。
3、给定一个非Proxy的Tensor对象,为以下情况生成IR:
- 对于Parameter,生成一个引用该Parameter的
get_attr
节点。 - 对于非Parameter的Tensor,将该Tensor存储在一个特殊属性中,并引用该属性。
可以重写此方法以支持更多类型。
参数
a (Any)
– 将被作为Argument
在Graph
中使用的值。
返回值:将值a
转换为适当的Argument
。
返回类型:Argument
注意:此API保证向后兼容。
create_args_for_root(root_fn, is_module, concrete_args=None)
为root
模块的签名创建对应的placeholder
节点。该方法会检查root模块的签名并据此生成这些节点,同时支持*args
和**kwargs
参数。
警告:此API为实验性质,且不向后兼容。
create_node(kind, target, args, kwargs, name=None, type_expr=None)
根据给定的目标、参数、关键字参数和名称插入一个图节点。
该方法可以被重写,用于在节点创建过程中对使用的值进行额外检查、验证或修改。例如,可能希望禁止记录原地操作。
注意:此API保证向后兼容性。
返回类型:Node
create_proxy(kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None)
根据给定的参数创建一个节点,然后返回包裹在 Proxy 对象中的节点。
如果 kind = ‘placeholder’,则表示我们正在创建一个代表函数参数的节点。若需要编码默认参数,则使用 args
元组。对于 placeholder
类型的节点,args
在其他情况下为空。
注意:此 API 保证向后兼容性。
get_fresh_qualname(prefix)
获取一个基于前缀的新名称并返回。该函数确保生成的名称不会与图中现有属性发生冲突。
注意:此API保证向后兼容。
返回类型:str
getattr(attr, attr_val, parameter_proxy_cache)
该方法定义了当对nn.Module
实例调用getattr时,该Tracer
的行为表现。
默认情况下,其行为是返回该属性的代理值。同时会将代理值存入parameter_proxy_cache
中,以便后续调用能复用该代理而非新建。
可通过重写此方法来实现不同行为——例如在查询参数时不返回代理。
参数说明:
attr (str)
- 被查询的属性名称attr_val (Any)
- 该属性的值parameter_proxy_cache (Dict[str, Any])
- 属性名到代理值的映射缓存
返回值:
getattr调用的返回结果。
警告:此API属于实验性质,且不保证向后兼容。
is_leaf_module(m, module_qualified_name)
一种用于判断给定nn.Module
是否为"叶子"模块的方法。
叶子模块是指出现在IR(中间表示)中的原子单元,通过call_module
调用进行引用。默认情况下,PyTorch标准库命名空间(torch.nn)中的模块都属于叶子模块。除非通过本参数特别指定,否则其他所有模块都会被追踪并记录其组成操作。
参数说明:
m (Module)
- 被查询的模块module_qualified_name (str)
- 该模块到根模块的路径。例如,若模块层级结构中子模块foo
包含子模块bar
,而bar
又包含子模块baz
,则该模块的限定名将显示为foo.bar.baz
返回类型:bool
注意:本API保证向后兼容性。
iter(obj)
当代理对象被迭代时调用,例如在控制流中使用时。通常我们不知道如何处理,因为我们不知道代理的值,但自定义跟踪器可以通过 create_node
向图节点附加更多信息,并可以选择返回一个迭代器。
注意:此 API 保证向后兼容性。
返回类型:迭代器
keys(obj)
当代理对象的 keys()
方法被调用时触发。这是在代理对象上调用 **
时发生的情况。该方法应返回一个迭代器,如果 **
需要在自定义追踪器中生效。
注意:此 API 保证向后兼容。
返回类型:任意
path_of_module(mod)
这是一个辅助方法,用于在root
模块的层级结构中查找mod
的限定名称。例如,如果root
有一个名为foo
的子模块,而foo
又有一个名为bar
的子模块,那么将bar
传入此函数将返回字符串"foo.bar"。
参数
mod (str)
– 需要获取限定名称的Module
。
返回类型:str
注意:此API保证向后兼容性。
proxy(node)
注意:此 API 保证向后兼容性。
返回类型:Proxy
to_bool(obj)
当代理对象需要转换为布尔值时调用,例如在控制流中使用时。通常我们无法确定如何处理,因为不知道代理的具体值,但自定义追踪器可以通过create_node
向图节点附加更多信息,并选择返回一个值。
注意:此API保证向后兼容。
返回类型:bool
trace(root, concrete_args=None)
追踪 root
并返回对应的 FX Graph
表示形式。root
可以是 nn.Module
实例或 Python 可调用对象。
请注意,在此调用后,self.root
可能与传入的 root
不同。例如,当向 trace()
传递自由函数时,我们会创建一个 nn.Module
实例作为根节点,并添加嵌入的常量。
参数
root (Union[Module, Callable])
– 需要追踪的Module
或函数。该参数保证向后兼容性。concrete_args (Optional[Dict[str, any]])
– 不应被视为代理的具体参数。此参数为实验性功能,其向后兼容性不作保证。
返回值:表示传入 root
语义的 Graph
对象。
返回类型:Graph
注意:此 API 保证向后兼容性。
class torch.fx.Proxy(node, tracer=None)
Proxy
对象是Node
包装器,在符号追踪过程中流经程序,并记录它们接触到的所有操作(包括torch
函数调用、方法调用和运算符)到不断增长的FX Graph中。
如果需要进行图变换,您可以在原始Node
上封装自己的Proxy
方法,这样就可以使用重载运算符向Graph
添加额外内容。
Proxy
对象不可迭代。换句话说,如果在循环中或作为*args
/**kwargs
函数参数使用Proxy
,符号追踪器会抛出错误。
有两种主要解决方法:
1、将不可追踪的逻辑提取到顶层函数中,并使用fx.wrap
进行处理。
2、如果控制流是静态的(即循环次数基于某些超参数),可以保持代码在原位,并重构为类似形式:
for i in range(self.some_hyperparameter):indexed_item = proxied_value[i]
如需更深入了解 Proxy 的内部实现细节,请查阅 torch/fx/README.md 文件中的 “Proxy” 章节。
注意:本 API 保证向后兼容性。
class torch.fx.Interpreter(module, garbage_collect_values=True, graph=None)
解释器(Interpreter)会逐节点(Node-by-Node)执行FX图。这种模式在许多场景下非常有用,包括编写代码转换器以及分析过程。
通过重写Interpreter类中的方法,可以自定义执行行为。以下是按调用层次结构划分的可重写方法映射:
run()+-- run_node+-- placeholder()+-- get_attr()+-- call_function()+-- call_method()+-- call_module()+-- output()
示例
假设我们需要将所有 torch.neg
实例与 torch.sigmoid
互换(包括它们对应的 Tensor
方法等价形式)。我们可以通过如下方式继承 Interpreter 类:
class NegSigmSwapInterpreter(Interpreter):def call_function(self, target: Target, args: Tuple, kwargs: Dict) -Any:if target == torch.sigmoid:return torch.neg(args, *kwargs)return super().call_function(target, args, kwargs)def call_method(self, target: Target, args: Tuple, kwargs: Dict) -Any:if target == "neg":call_self, args_tail = argsreturn call_self.sigmoid(args_tail, *kwargs)return super().call_method(target, args, kwargs)def fn(x):return torch.sigmoid(x).neg()gm = torch.fx.symbolic_trace(fn)
input = torch.randn(3, 4)
result = NegSigmSwapInterpreter(gm).run(input)
torch.testing.assert_close(result, torch.neg(input).sigmoid())
参数
module ( torch.nn.Module )
– 待执行的模块garbage_collect_values ([bool])
– 是否在模块执行过程中最后一次使用后删除值。这能确保执行期间内存使用最优。可以禁用此功能,例如通过查看Interpreter.env
属性来检查执行中的所有中间值。graph (Optional[Graph])
– 如果传入该参数,解释器将执行此图而非module.graph,并使用提供的模块参数来满足任何状态请求。
注意:此API保证向后兼容性。
boxed_run(args_list)
通过解释方式运行模块并返回结果。该过程采用"boxed"调用约定,即传递一个参数列表(这些参数会被解释器自动清除),从而确保输入张量能够及时释放。
注意:本API保证向后兼容性。
call_function(target, args, kwargs)
执行一个call_function
节点并返回结果。
参数
target (Target)
– 该节点的调用目标。关于语义的详细信息请参阅Nodeargs (Tuple)
– 本次调用的位置参数元组kwargs (Dict)
– 本次调用的关键字参数字典
返回类型:任意类型
返回值: 函数调用返回的值
注意:此API保证向后兼容性。
call_method(target, args, kwargs)
执行一个 call_method
节点并返回结果。
参数
target (Target)
– 该节点的调用目标。有关语义的详细信息,请参阅 Nodeargs (Tuple)
– 该调用的位置参数元组kwargs (Dict)
– 该调用的关键字参数字典
返回类型:任意
返回值:方法调用返回的值
注意:此 API 保证向后兼容性。
call_module(target, args, kwargs)
执行一个call_module
节点并返回结果。
参数
target (Target)
– 该节点的调用目标。关于语义的详细信息请参阅
Nodeargs (Tuple)
– 本次调用的位置参数元组kwargs (Dict)
– 本次调用的关键字参数字典
返回类型:Any
返回值:模块调用返回的值
注意:此API保证向后兼容性。
fetch_args_kwargs_from_env(n)
从当前执行环境中获取节点n
的args
和kwargs
具体值
参数
n (Node)
– 需要获取args
和kwargs
的目标节点
返回值
节点n
对应的具体args
和kwargs
值
返回类型:Tuple[Tuple, Dict]
注意:本API保证向后兼容性
fetch_attr(target)
从 self.module
的 Module
层级结构中获取一个属性。
参数
target (str)
- 要获取属性的全限定名称
返回
该属性的值。
返回类型
任意类型
注意:此 API 保证向后兼容。
get_attr(target, args, kwargs)
执行一个 get_attr
节点。该操作会从 self.module
的 Module
层级结构中获取属性值。
参数
target (Target)
– 该节点的调用目标。关于语义的详细信息请参阅 Nodeargs (Tuple)
– 本次调用的位置参数元组kwargs (Dict)
– 本次调用的关键字参数字典
返回值
获取到的属性值
返回类型
任意类型
注意:此 API 保证向后兼容性。
map_nodes_to_values(args, n)
递归遍历 args
并在当前执行环境中查找每个 Node
的具体值。
参数
args (Argument)
– 需要查找具体值的数据结构n (Node)
–args
所属的节点。仅用于错误报告。
返回类型:Optional[Union [tuple [Argument’, …], Sequence [Argument], Mapping [str , Argument], slice , range , Node, str , int , float, bool , complex , dtype, Tensor , device , memory_format , layout , OpOverload, SymInt, SymBool , SymFloat ]]
注意:此 API 保证向后兼容性。
output(target, args, kwargs)
执行一个output
节点。该操作实际上只是获取output
节点引用的值并返回它。
参数
target (Target)
– 该节点的调用目标。有关语义详情请参阅
Nodeargs (Tuple)
– 本次调用的位置参数元组kwargs (Dict)
– 本次调用的关键字参数字典
返回值:输出节点引用的返回值
返回类型:任意类型
注意:此API保证向后兼容。
placeholder(target, args, kwargs)
执行一个placeholder
节点。请注意这是有状态的:
Interpreter
内部维护了一个针对run
方法传入参数的迭代器,本方法会返回该迭代器的next()结果。
参数
target (Target)
– 该节点的调用目标。关于语义的详细信息请参阅Nodeargs (Tuple)
– 本次调用的位置参数元组kwargs (Dict)
– 本次调用的关键字参数字典
返回值:获取到的参数值。
返回类型:任意类型
注意:此API保证向后兼容。
run(*args, initial_env=None, enable_io_processing=True)
通过解释执行模块并返回结果。
参数
*args
– 按位置顺序传递给模块的运行参数initial_env (Optional[Dict[Node, Any]])
– 可选的执行初始环境。这是一个将节点映射到任意值的字典。例如,可用于预先填充某些节点的结果,从而在解释器中仅进行部分求值。enable_io_processing ([bool])
– 如果为true,我们会在使用输入和输出之前,先用图的process_inputs和process_outputs函数对它们进行处理。
返回值:执行模块后返回的值
返回类型:任意
注意:此API保证向后兼容。
run_node(n)
运行特定节点 n
并返回结果。
根据 node.op
的类型,调用对应的占位符、get_attr、call_function、call_method、call_module 或 output 方法。
参数
n (Node)
– 需要执行的节点
返回值:执行节点 n
的结果
返回类型:任意类型
注意:此 API 保证向后兼容性。
class torch.fx.Transformer(module)
Transformer
是一种特殊类型的解释器,用于生成新的 Module
。它提供了一个 transform()
方法,返回转换后的 Module
。与 Interpreter
不同,Transformer
不需要参数即可运行,完全基于符号化方式工作。
示例
假设我们需要将所有 torch.neg
实例与 torch.sigmoid
互换(包括它们的 Tensor
方法等效形式)。可以通过如下方式子类化 Transformer
:
class NegSigmSwapXformer(Transformer):def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:if target == torch.sigmoid:return torch.neg(*args, **kwargs)return super().call_function(target, args, kwargs)def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:if target == "neg":call_self, *args_tail = argsreturn call_self.sigmoid(*args_tail, **kwargs)return super().call_method(target, args, kwargs)def fn(x):return torch.sigmoid(x).neg()gm = torch.fx.symbolic_trace(fn)transformed: torch.nn.Module = NegSigmSwapXformer(gm).transform()
input = torch.randn(3, 4)
torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())
参数
module ([GraphModule](https://pytorch.org/docs/stable/data.html#torch.fx.GraphModule "torch.fx.GraphModule"))
– 待转换的Module
对象。
注意:此API保证向后兼容性。
call_function(target, args, kwargs)
注意:该 API 保证向后兼容。
返回类型
Any
call_module(target, args, kwargs)
注意:此 API 保证向后兼容。
返回类型
Any
get_attr(target, args, kwargs)
执行一个 get_attr
节点。在 Transformer
中,该方法被重写以便向输出图中插入新的 get_attr
节点。
参数
target (Target)
– 该节点的调用目标。关于语义的详细信息请参阅
Nodeargs (Tuple)
– 该调用的位置参数元组kwargs (Dict)
– 该调用的关键字参数字典
返回类型
Proxy
注意:此 API 保证向后兼容。
placeholder(target, args, kwargs)
执行一个 placeholder
节点。在 Transformer
中,该方法被重写以便向输出图中插入新的 placeholder
。
参数
target (Target)
– 该节点的调用目标。关于语义的详细信息请参阅 Nodeargs (Tuple)
– 该调用的位置参数元组kwargs (Dict)
– 该调用的关键字参数字典
返回类型:Proxy
注意:此 API 保证向后兼容。
transform()
转换 self.module
并返回转换后的 GraphModule
。
注意:此 API 保证向后兼容性。
返回类型 : GraphModule
torch.fx.replace_pattern(gm, pattern, replacement)
在GraphModule的图结构(gm
)中,匹配所有可能的非重叠运算符集及其数据依赖关系(pattern
),然后将每个匹配到的子图替换为另一个子图(replacement
)。
参数
gm (GraphModule)
- 封装待操作图的GraphModulepattern (Union[Callable, GraphModule])
- 需要在gm
中匹配并替换的子图replacement (Union[Callable, GraphModule])
- 用于替换pattern
的子图
返回值:返回一个Match
对象列表,表示原始图中与pattern
匹配的位置。如果没有匹配项则返回空列表。Match
定义如下:
class Match(NamedTuple):# Node from which the match was foundanchor: Node# Maps nodes in the pattern subgraph to nodes in the larger graphnodes_map: Dict[Node, Node]
返回类型:List[Match]
示例:
import torch
from torch.fx import symbolic_trace, subgraph_rewriterclass M(torch.nn.Module):def __init__(self) -None:super().__init__()def forward(self, x, w1, w2):m1 = torch.cat([w1, w2]).sum()m2 = torch.cat([w1, w2]).sum()return x + torch.max(m1) + torch.max(m2)def pattern(w1, w2):return torch.cat([w1, w2])def replacement(w1, w2):return torch.stack([w1, w2])traced_module = symbolic_trace(M())subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
上述代码会先在 traced_module
的 forward
方法中匹配 pattern
。模式匹配基于使用-定义关系而非节点名称进行。例如,若 pattern
中包含 p = torch.cat([a, b])
,则可以在原始 forward
函数中匹配到 m = torch.cat([a, b])
,即使变量名不同(p
与 m
)也不影响。
pattern
中的 return
语句仅根据其值进行匹配,它可能与更大图中的 return
语句匹配,也可能不匹配。换句话说,模式不必延伸至更大图的末尾。
当模式匹配成功时,它将从更大的函数中被移除,并由 replacement
替换。如果更大函数中存在多个 pattern
匹配项,每个非重叠的匹配项都会被替换。若出现匹配重叠的情况,则替换重叠匹配集中最先找到的匹配项(此处的"最先"定义为节点使用-定义关系拓扑排序中的第一个节点。大多数情况下,第一个节点是紧接 self
后出现的参数,而最后一个节点是函数返回的内容)。
需要特别注意:pattern
可调用对象的参数必须在该可调用对象内部使用,且 replacement
可调用对象的参数必须与模式匹配。第一条规则解释了为何上述代码块中 forward
函数有参数 x, w1, w2
,而 pattern
函数只有参数 w1, w2
——因为 pattern
未使用 x
,故不应将 x
指定为参数。
关于第二条规则的示例,考虑替换…
def pattern(x, y):return torch.neg(x) + torch.relu(y)
with
def replacement(x, y):return torch.relu(x)
在这种情况下,replacement
需要与pattern
相同数量的参数(包括x
和y
),即使参数y
在replacement
中并未使用。
调用subgraph_rewriter.replace_pattern
后,生成的Python代码如下所示:
def forward(self, x, w1, w2):stack_1 = torch.stack([w1, w2])sum_1 = stack_1.sum()stack_2 = torch.stack([w1, w2])sum_2 = stack_2.sum()max_1 = torch.max(sum_1)add_1 = x + max_1max_2 = torch.max(sum_2)add_2 = add_1 + max_2return add_2
注意:该 API 保证向后兼容。
torch.fx.experimental
警告:这些API属于实验性质,可能会随时变更而不另行通知。
torch.fx.experimental.symbolic_shapes
ShapeEnv | |
---|---|
DimDynamic | 控制如何为维度分配符号。 |
StrictMinMaxConstraint | 对客户端:该维度的大小必须在’vr’范围内(指定包含性上下界),且必须为非负数且不应为0或1(但参见下方注意事项)。 |
RelaxedUnspecConstraint | 对客户端:无显式约束;约束由追踪过程中的守卫隐式推断得出。 |
EqualityConstraint | 表示并判定输入源之间的各类相等性约束。 |
SymbolicContext | 数据结构,指定在create_symbolic_sizes_strides_storage_offset 中如何创建符号;例如,应为静态还是动态。 |
StatelessSymbolicContext | 通过DimDynamic 和DimConstraint 给定的symbolic_context判定,在create_symbolic_sizes_strides_storage_offset 中创建符号。 |
StatefulSymbolicContext | 通过Source:Symbol缓存给定的symbolic_context判定,在create_symbolic_sizes_strides_storage_offset 中创建符号。 |
SubclassSymbolicContext | 可追踪张量子类的内部张量的正确符号上下文可能与外部符号上下文不同。 |
DimConstraints | 针对符号维度约束系统的自定义求解器。 |
ShapeEnvSettings | 封装所有可能影响FakeTensor调度的形状环境设置。 |
ConvertIntKey | |
CallMethodKey | |
PropagateUnbackedSymInts | |
DivideByKey | |
InnerTensorKey | |
hint_int | 获取整数的提示值(基于运行时观察到的底层实际值)。 |
is_concrete_int | 检查SymInt底层对象是否为具体值的实用工具。 |
is_concrete_bool | 检查SymBool底层对象是否为具体值的实用工具。 |
is_concrete_float | 检查SymInt底层对象是否为具体值的实用工具。 |
has_free_symbols | bool(free_symbols(val))的快速版本 |
has_free_unbacked_symbols | bool(free_unbacked_symbols(val))的快速版本 |
definitely_true | 仅当能确定a为True时返回True,过程中可能引入守卫。 |
definitely_false | 仅当能确定a为False时返回True,过程中可能引入守卫。 |
guard_size_oblivious | 以无视大小的方式对符号布尔表达式执行守卫。 |
sym_eq | 类似==,但在列表/元组上运行时,会递归测试相等性并使用sym_and连接结果,不引入守卫。 |
constrain_range | 应用约束使传入的SymInt必须在min-max范围内(包含边界),且不引入SymInt的守卫(意味着可用于未绑定的SymInt)。 |
constrain_unify | 给定两个SymInt,约束它们必须相等。 |
canonicalize_bool_expr | 通过将布尔表达式转换为lt/le不等式并将所有非常量项移至右侧,实现规范化。 |
statically_known_true | 如果x可简化为常量且为真,则返回True。 |
lru_cache | |
check_consistent | 测试两个"meta"值(通常为Tensor或SymInt)是否具有相同的值,例如在重追踪后。 |
compute_unbacked_bindings | 在运行fake tensor传播并生成example_value结果后,遍历example_value查找新绑定的未支持符号并记录其路径供后续使用。 |
rebind_unbacked | 假设我们正在重追踪一个已有FX图,该图先前进行过fake tensor传播(因此存在未支持的SymInt)。 |
resolve_unbacked_bindings | |
is_accessor_node |
torch.fx.experimental.proxy_tensor
make_fx | 给定函数f,返回一个新函数。当使用有效参数执行该函数时,会返回一个FX GraphModule,表示执行过程中所执行的操作集合。 |
---|---|
handle_sym_dispatch | 调用当前活动的代理跟踪模式,对操作SymInt/SymFloat/SymBool参数的函数进行符号调度跟踪。 |
get_proxy_mode | 获取当前活动的代理跟踪模式,如果当前未处于跟踪状态则返回None。 |
maybe_enable_thunkify | 在此上下文管理器内,如果正在进行make_fx跟踪,将对所有SymNode计算进行thunkify处理,并避免将其跟踪到图中,除非确实需要。 |
maybe_disable_thunkify | 在某个上下文中禁用thunkification功能。 |
torch.hub
PyTorch Hub 是一个预训练模型仓库,旨在促进研究可复现性。
发布模型
PyTorch Hub 支持通过添加简单的 hubconf.py
文件,将预训练模型(模型定义和预训练权重)发布到 GitHub 仓库。
hubconf.py
可以包含多个入口点。每个入口点都定义为 Python 函数(例如:您想发布的预训练模型)。
def entrypoint_name(args, *kwargs):# args & kwargs are optional, for models which take positional/keyword arguments....
如何实现入口点?
以下代码片段展示了如果我们扩展 pytorch/vision/hubconf.py
中的实现,如何为 resnet18
模型指定入口点。在大多数情况下,只需在 hubconf.py
中导入正确的函数就足够了。这里我们使用扩展版本作为示例来说明其工作原理。
完整脚本可查看 pytorch/vision 代码库
dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18# resnet18 is the name of entrypoint
def resnet18(pretrained=False, *kwargs):""" # This docstring shows up in hub.help()Resnet18 modelpretrained (bool): kwargs, load pretrained weights into the model"""# Call the model, load pretrained weightsmodel = _resnet18(pretrained=pretrained, *kwargs)return model
dependencies
变量是一个列表,包含加载模型所需的包名。注意这里可能与训练模型所需的依赖项略有不同。args
和kwargs
会传递给实际的可调用函数。- 函数的文档字符串(docstring)将作为帮助信息。它需要说明模型的功能以及允许的位置参数/关键字参数。强烈建议在此处添加几个示例。
- 入口函数可以返回一个模型(nn.Module),也可以返回辅助工具(如分词器)来优化用户工作流程。
- 以下划线开头的可调用对象被视为辅助函数,不会出现在
torch.hub.list()
的返回结果中。 - 预训练权重可以存储在GitHub仓库本地,也可以通过
torch.hub.load_state_dict_from_url()
加载。如果小于2GB,建议将其附加到项目发布版并使用发布版的URL。
在上面的示例中,torchvision.models.resnet.resnet18
处理了 pretrained
参数,你也可以将以下逻辑放在入口函数定义中。
# For checkpoint saved in local GitHub repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pthdirname = os.path.dirname(__file__)checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)state_dict = torch.load(checkpoint)model.load_state_dict(state_dict)# For checkpoint saved elsewherecheckpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
重要通知
- 发布的模型至少应位于分支/标签中,不能是随机提交。
从Hub加载模型
PyTorch Hub 提供了一系列便捷的API,帮助开发者探索Hub中所有可用模型:
- 通过
torch.hub.list()
查看所有模型 - 使用
torch.hub.help()
显示文档说明和示例 - 调用
torch.hub.load()
加载预训练模型
torch.hub.list(github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True)
列出由 github
指定的代码仓库中所有可调用的入口点。
参数
github (str)
– 格式为“repo_owner/repo_name[:ref]”的字符串,其中 ref(标签或分支)为可选。如果未指定ref
,则默认分支为main
(如果存在),否则为master
。
示例:‘pytorch/vision:0.10’force_reload ([bool], 可选)
– 是否丢弃现有缓存并强制重新下载。默认为False
。skip_validation ([bool], 可选)
– 如果为False
,torchhub 会检查github
参数指定的分支或提交是否确实属于该仓库所有者。此操作会向 GitHub API 发起请求;可通过设置GITHUB_TOKEN
环境变量指定非默认的 GitHub 令牌。默认为False
。trust_repo ([bool],* str 或 *None)
–
"check"
、True
、False
或None
。
此参数在 v1.12 版本引入,用于确保用户仅运行信任的仓库代码。
- 如果为
False
,会提示用户确认是否信任该仓库。 - 如果为
True
,仓库将被添加到信任列表并直接加载,无需明确确认。 - 如果为
"check"
,会检查该仓库是否在缓存的信任列表中。若不存在,则回退到trust_repo=False
的行为。 - 如果为
None
:会发出警告,提示用户将trust_repo
设为False
、True
或"check"
。此选项仅为向后兼容保留,将在 v2.0 版本移除。默认为None
,未来 v2.0 版本将改为默认"check"
。
verbose ([bool], 可选)
– 如果为False
,则屏蔽关于命中本地缓存的消息。注意首次下载的消息无法屏蔽。默认为True
。
返回
可用的可调用入口点列表
返回类型
list
示例
>>> entrypoints = torch.hub.list("pytorch/vision", force_reload=True)
torch.hub.help(github, model, force_reload=False, skip_validation=False, trust_repo=None)
显示入口点 model
的文档字符串。
参数
github (str)
– 格式为 <repo_owner/repo_name[:ref]> 的字符串,其中 ref(标签或分支)是可选的。如果未指定ref
,则默认分支为main
(如果存在),否则为master
。
示例:‘pytorch/vision:0.10’
model (str)
– 仓库hubconf.py
中定义的入口点名称字符串force_reload ([bool], 可选)
– 是否丢弃现有缓存并强制重新下载。默认为False
。skip_validation ([bool], 可选)
– 如果为False
,torchhub 将检查github
参数指定的 ref 是否确实属于该仓库所有者。这将向 GitHub API 发出请求;您可以通过设置GITHUB_TOKEN
环境变量来指定非默认的 GitHub 令牌。默认为False
。trust_repo ([bool], *str 或 *None)
–
"check"
、True
、False
或 None
。
此参数在 v1.12 版本引入,用于确保用户仅运行来自受信任仓库的代码。
-
如果为
False
,将提示用户确认是否信任该仓库。 -
如果为
True
,该仓库将被添加到受信任列表并直接加载,无需明确确认。 -
如果为
"check"
,将检查该仓库是否在缓存的受信任仓库列表中。如果不在列表中,行为将回退到trust_repo=False
选项。 -
如果为
None
:将发出警告,提示用户将trust_repo
设置为False
、True
或"check"
。此选项仅用于向后兼容,将在 v2.0 版本移除。默认为None
,最终将在 v2.0 版本更改为"check"
。
示例
>>> print(torch.hub.help("pytorch/vision", "resnet18", force_reload=True))
torch.hub.load(repo_or_dir, model, *args, source='github', trust_repo=None, force_reload=False, verbose=True, skip_validation=False, **kwargs)
从 GitHub 仓库或本地目录加载模型。
注意:加载模型是典型用例,但此功能也可用于加载其他对象,如分词器、损失函数等。
如果 source
为 ‘github’,则 repo_or_dir
应为 repo_owner/repo_name[:ref]
格式,其中 ref(标签或分支)为可选项。
如果 source
为 ‘local’,则 repo_or_dir
应为本地目录路径。
参数
repo_or_dir (str)
– 如果source
为 ‘github’,则应为 GitHub 仓库,格式为repo_owner/repo_name[:ref]
(ref 为可选的标签或分支),例如 ‘pytorch/vision:0.10’。如果未指定ref
,则默认分支为main
(如果存在),否则为master
。
如果 source
为 ‘local’,则应为本地目录路径。
model (str)
– 仓库/目录中hubconf.py
文件定义的可调用对象(入口点)名称。*args (可选)
– 可调用对象model
的对应参数。source (str, 可选)
– ‘github’ 或 ‘local’。指定如何解释repo_or_dir
。默认为 ‘github’。trust_repo ([bool],* str 或 *None)
–
"check"
、True
、False
或 None
。
此参数在 v1.12 中引入,用于确保用户仅运行信任仓库中的代码。
-
如果为
False
,将提示用户确认是否信任该仓库。 -
如果为
True
,该仓库将被添加到信任列表并直接加载,无需明确确认。 -
如果为
"check"
,将检查该仓库是否在缓存信任列表中。如果不在,则回退到trust_repo=False
的行为。 -
如果为
None
:将发出警告,提示用户将trust_repo
设置为False
、True
或"check"
。此选项仅用于向后兼容,将在 v2.0 中移除。默认为None
,最终将在 v2.0 中改为"check"
。
force_reload ([bool], 可选)
– 是否无条件强制重新下载 GitHub 仓库。如果source = 'local'
则无效。默认为False
。verbose ([bool], 可选)
– 如果为False
,则屏蔽有关命中本地缓存的消息。注意首次下载的消息无法屏蔽。如果source = 'local'
则无效。默认为True
。skip_validation ([bool], 可选)
– 如果为False
,torchhub 将检查github
参数指定的分支或提交是否属于该仓库所有者。这将向 GitHub API 发出请求;您可通过设置GITHUB_TOKEN
环境变量指定非默认的 GitHub 令牌。默认为False
。**kwargs (可选)
– 可调用对象model
的对应关键字参数。
返回
调用 model
可调用对象时,传入给定 *args
和 **kwargs
的输出。
示例
>>> # from a github repo
>>> repo = "pytorch/vision"
>>> model = torch.hub.load(
... repo, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1"
... )
>>> # from a local directory
>>> path = "/some/local/path/pytorch/vision"
>>> model = torch.hub.load(path, "resnet50", weights="ResNet50_Weights.DEFAULT")
torch.hub.download_url_to_file(url, dst, hash_prefix=None, progress=True)
将给定URL的对象下载到本地路径。
参数
url (str)
- 要下载对象的URLdst (str)
- 对象将被保存的完整路径,例如/tmp/temporary_file
hash_prefix (str, 可选)
- 如果不为None,下载文件的SHA256哈希值应以hash_prefix
开头。默认值:Noneprogress ([bool], 可选)
- 是否在标准错误输出中显示进度条。默认值:True
示例
>>> torch.hub.download_url_to_file(
... "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth",
... "/tmp/temporary_file",
... )
torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None, weights_only=False)
从给定的URL加载Torch序列化对象。
如果下载的文件是zip压缩包,系统会自动解压。
如果对象已存在于model_dir目录中,则直接反序列化并返回。
model_dir
的默认值为<hub_dir>/checkpoints
,其中hub_dir
是由get_dir()
返回的目录路径。
参数说明
url (str)
- 需要下载对象的URL地址model_dir (str, 可选)
- 保存对象的目录路径map_location (可选)
- 指定存储位置重映射的函数或字典(参见torch.load)progress ([bool], 可选)
- 是否在标准错误输出中显示进度条。默认值:Truecheck_hash ([bool], 可选)
- 若为True,则URL中的文件名部分需遵循命名规范:
filename-<sha256>.ext
,其中<sha256>
是文件内容SHA256哈希值的前8位或更多位数字。该哈希值用于确保唯一文件名并验证文件内容。默认值:Falsefile_name (str, 可选)
- 下载文件的名称。若未设置,则使用URL中的文件名weights_only ([bool], 可选)
- 若为True,则仅加载权重而不加载复杂的pickle对象。建议用于不可信来源。详见load()
说明
返回类型:dict[str, Any]
使用示例
>>> state_dict = torch.hub.load_state_dict_from_url(
... "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth"
... )
运行加载的模型:
请注意,torch.hub.load()
中的 *args
和 **kwargs
用于实例化模型。加载模型后,如何了解该模型的功能?
建议的工作流程如下:
- 使用
dir(model)
查看模型所有可用的方法 - 通过
help(model.foo)
查看model.foo
运行所需的参数
为了帮助用户无需反复查阅文档即可探索功能,我们强烈建议仓库维护者确保函数帮助信息清晰简洁。同时,提供一个最小可运行示例也非常有帮助。
下载的模型保存在哪里?
模型保存路径按以下顺序确定:
1、调用 hub.set_dir(<PATH_TO_HUB_DIR>)
设置的路径
2、若设置了环境变量 TORCH_HOME
,则使用 $TORCH_HOME/hub
3、若设置了环境变量 XDG_CACHE_HOME
,则使用 $XDG_CACHE_HOME/torch/hub
4、默认路径为 ~/.cache/torch/hub
可通过 torch.hub.get_dir()
获取当前保存路径
***
Get the Torch Hub cache directory used for storing downloaded models \& weights.If [`set_dir()`](https://pytorch.org/docs/stable/data.html#torch.hub.set_dir "torch.hub.set_dir") is not called, default path is `$TORCH_HOME/hub` where
environment variable `$TORCH_HOME` defaults to `$XDG_CACHE_HOME/torch`.
`$XDG_CACHE_HOME` follows the X Design Group specification of the Linux
filesystem layout, with a default value `~/.cache` if the environment
variable is not set.Return typestr torch.hub.set_dir(d)
可选设置用于保存下载模型和权重的 Torch Hub 目录。
参数
d (str)
– 用于保存下载模型和权重的本地文件夹路径。
缓存逻辑
默认情况下,我们在加载文件后不会进行清理。如果文件已存在于 get_dir()
返回的目录中,Hub 会默认使用缓存。
用户可以通过调用 hub.load(..., force_reload=True)
强制重新加载。这将删除现有的 GitHub 文件夹和下载的权重文件,并重新初始化全新下载。当同一分支发布更新时,这个功能非常有用,用户可以及时获取最新版本。
已知限制:
Torch hub 的工作原理是将包当作已安装的包进行导入。Python 导入机制会带来一些副作用,例如你可能会在 Python 缓存 sys.modules
和 sys.path_importer_cache
中看到新增条目,这是 Python 的正常行为。这也意味着,如果不同代码仓库包含同名的子包(通常是名为 model
的子包),在从不同仓库导入不同模型时可能会遇到导入错误。针对这类导入错误的解决方案是从 sys.modules
字典中移除冲突的子包,更多细节可参考 这个 GitHub issue。
需要特别说明的一个已知限制:用户无法在同一个 Python 进程中加载同一代码仓库的两个不同分支。这就像在 Python 中安装两个同名的包一样,是不合理的做法。如果强行尝试,缓存机制可能会介入并带来意外结果。当然,在独立的进程中分别加载它们是完全可行的。
TorchScript
TorchScript 是一种将 PyTorch 代码转换为可序列化和可优化模型的方法。任何 TorchScript 程序都可以从 Python 进程中保存,并在不依赖 Python 的环境中加载运行。
我们提供了一系列工具,帮助开发者逐步将纯 Python 模型转换为可独立于 Python 运行的 TorchScript 程序,例如在独立的 C++ 程序中运行。这使得开发者能够继续使用熟悉的 Python 工具训练 PyTorch 模型,然后通过 TorchScript 将模型导出到生产环境。在生产环境中,由于性能和多线程方面的考虑,使用 Python 程序可能并不合适。
如需了解 TorchScript 的入门指南,请参阅 TorchScript 简介教程。
若想查看将 PyTorch 模型转换为 TorchScript 并在 C++ 中运行的完整示例,请参考 在 C++ 中加载 PyTorch 模型 教程。
创建 TorchScript 代码
script | 将函数转换为脚本 |
---|---|
trace | 追踪函数并返回一个可执行对象或ScriptFunction ,该对象将通过即时编译进行优化 |
script_if_tracing | 在追踪过程中首次调用时编译fn |
trace_module | 追踪模块并返回一个可执行的ScriptModule ,该模块将通过即时编译进行优化 |
fork | 创建异步任务执行函数,并返回对该执行结果的引用 |
wait | 强制完成torch.jit.Future[T]异步任务,返回任务结果 |
ScriptModule | C++ torch::jit::Module的包装器,包含方法、属性和参数 |
ScriptFunction | 功能上等同于ScriptModule ,但表示单个函数且不包含任何属性或参数 |
freeze | 冻结ScriptModule,将子模块和属性内联为常量 |
optimize_for_inference | 执行一系列优化步骤,为推理目的优化模型 |
enable_onednn_fusion | 根据参数enabled启用或禁用onednn JIT融合 |
onednn_fusion_enabled | 返回onednn JIT融合是否启用 |
set_fusion_strategy | 设置融合过程中可能发生的特化类型和数量 |
strict_fusion | 如果推理中未融合所有节点或训练中未符号微分,则报错 |
save | 保存此模块的离线版本以供其他进程使用 |
load | 加载先前用torch.jit.save 保存的ScriptModule 或ScriptFunction |
ignore | 此装饰器向编译器表明应忽略函数或方法,保留为Python函数 |
unused | 此装饰器向编译器表明应忽略函数或方法,并替换为抛出异常 |
interface | 用于注解不同类型类或模块的装饰器 |
isinstance | 在TorchScript中提供容器类型细化 |
Attribute | 此方法是一个返回值的直通函数,主要用于向TorchScript编译器表明左侧表达式是具有type类型的类实例属性 |
annotate | 用于在TorchScript编译器中指定the_value的类型 |
混合使用追踪与脚本化
在多数情况下,追踪(tracing)或脚本化(scripting)都是将模型转换为 TorchScript 的更简便方式。根据模型不同部分的具体需求,可以组合使用这两种方法。
脚本化函数能够调用追踪生成的函数。当需要在简单前馈模型周围添加控制流逻辑时,这种方式特别有用。例如,序列到序列模型中的束搜索(beam search)通常会用脚本编写,但可以调用通过追踪生成的编码器模块。
示例(在脚本中调用追踪函数):
import torchdef foo(x, y):return 2 * x + ytraced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))@torch.jit.script
def bar(x):return traced_foo(x, x)
被追踪的函数可以调用脚本函数。当模型的大部分只是前馈网络,而其中一小部分需要控制流时,这非常有用。在被追踪函数调用的脚本函数内部,控制流会被正确保留。
示例(在被追踪函数中调用脚本函数):
import torch@torch.jit.script
def foo(x, y):if x.max() y.max():r = xelse:r = yreturn rdef bar(x, y, z):return foo(x, y) + ztraced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))
该组合方式同样适用于 nn.Module
,它可以通过追踪生成一个子模块,该子模块可从脚本模块的方法中调用。
示例(使用追踪模块):
import torch
import torchvisionclass MyScriptModule(torch.nn.Module):def __init__(self):super().__init__()self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68]).resize_(1, 3, 1, 1))self.resnet = torch.jit.trace(torchvision.models.resnet18(), torch.rand(1, 3, 224, 224))def forward(self, input):return self.resnet(input - self.means)my_script_module = torch.jit.script(MyScriptModule())
TorchScript 语言
TorchScript 是 Python 的一个静态类型子集,因此许多 Python 特性可以直接应用于 TorchScript。详情请参阅完整的 TorchScript 语言参考。
内置函数与模块
TorchScript 支持使用大多数 PyTorch 函数和许多 Python 内置功能。完整支持函数列表请参阅 TorchScript 内置函数。
PyTorch 函数与模块
TorchScript 支持 PyTorch 提供的张量和神经网络函数子集。Tensor 上的大多数方法、torch
命名空间中的所有函数、torch.nn.functional
中的全部函数以及 torch.nn
中的大多数模块均可被 TorchScript 支持。
不支持的 PyTorch 函数和模块列表请参阅 TorchScript 不支持的 PyTorch 结构。
Python 函数与模块
TorchScript 支持许多 Python 的内置函数。
math
模块同样受支持(详见数学模块),但其他 Python 模块(无论是内置还是第三方)均不支持。
Python 语言参考对比
如需查看支持的 Python 功能完整列表,请参阅 Python 语言参考覆盖范围。
调试
禁用 JIT 进行调试
PYTORCH_JIT
设置环境变量 PYTORCH_JIT=0
将禁用所有脚本和追踪注解。当您的 TorchScript 模型中出现难以调试的错误时,可以通过此标志强制所有代码以原生 Python 方式运行。由于该标志会禁用 TorchScript(脚本化和追踪),您可以使用诸如 pdb
之类的工具来调试模型代码。例如:
@torch.jit.script
def scripted_fn(x : torch.Tensor):for i in range(12):x = x + xreturn xdef fn(x):x = torch.neg(x)import pdb; pdb.set_trace()return scripted_fn(x)traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))
traced_fn(torch.rand(3, 4))
使用 pdb
调试此脚本时一切正常,但当我们调用 @torch.jit.script
函数时会失效。我们可以全局禁用 JIT 功能,这样就能将 @torch.jit.script
作为普通 Python 函数调用而不进行编译。如果上述脚本名为 disable_jit_example.py
,可以通过以下方式调用:
$ PYTORCH_JIT=0 python disable_jit_example.py
这样我们就能像普通 Python 函数一样单步调试 @torch.jit.script
装饰的函数。如需禁用特定函数的 TorchScript 编译器,请参阅 @torch.jit.ignore
。
代码检查
TorchScript 为所有 ScriptModule
实例提供了代码美化打印器。该美化打印器能够将脚本方法的代码以有效的 Python 语法形式呈现。例如:
@torch.jit.script
def foo(len):# type: (int) -torch.Tensorrv = torch.zeros(3, 4)for i in range(len):if i < 10:rv = rv - 1.0else:rv = rv + 1.0return rvprint(foo.code)
一个包含单个 forward
方法的 ScriptModule
会有一个 code
属性,你可以通过该属性来检查 ScriptModule
的代码。
如果 ScriptModule
包含多个方法,你需要访问方法本身的 .code
属性,而不是模块的。我们可以通过访问 .foo.code
来检查 ScriptModule
上名为 foo
的方法代码。
上面的示例会产生以下输出:
def foo(len: int) -Tensor:rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)rv0 = rvfor i in range(len):if torch.lt(i, 10):rv1 = torch.sub(rv0, 1., 1)else:rv1 = torch.add(rv0, 1., 1)rv0 = rv1return rv0
这是 TorchScript 对 forward
方法代码的编译结果。
您可以通过它来验证 TorchScript(无论是通过追踪还是脚本化方式)是否正确捕获了您的模型代码。
解读图结构
TorchScript 在代码美化打印器之下还有一个更低层次的表示形式,即 IR(中间表示)图。
TorchScript 采用静态单赋值(SSA)中间表示(IR)来描述计算过程。这种格式的指令由 ATen(PyTorch 的 C++ 后端)运算符和其他基础运算符组成,包括用于循环和条件判断的控制流运算符。例如:
@torch.jit.script
def foo(len):# type: (int) -torch.Tensorrv = torch.zeros(3, 4)for i in range(len):if i < 10:rv = rv - 1.0else:rv = rv + 1.0return rvprint(foo.graph)
graph
遵循与代码检查章节中描述的相同规则,涉及 forward
方法查找。
上面的示例脚本生成如下图表:
graph(%len.1 : int):%24 : int = prim::Constant[value=1]()%17 : bool = prim::Constant[value=1]() # test.py:10:5%12 : bool? = prim::Constant()%10 : Device? = prim::Constant()%6 : int? = prim::Constant()%1 : int = prim::Constant[value=3]() # test.py:9:22%2 : int = prim::Constant[value=4]() # test.py:9:25%20 : int = prim::Constant[value=10]() # test.py:11:16%23 : float = prim::Constant[value=1]() # test.py:12:23%4 : int[] = prim::ListConstruct(%1, %2)%rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10%rv : Tensor = prim::Loop(%len.1, %17, %rv.1) # test.py:10:5block0(%i.1 : int, %rv.14 : Tensor):%21 : bool = aten::lt(%i.1, %20) # test.py:11:12%rv.13 : Tensor = prim::If(%21) # test.py:11:9block0():%rv.3 : Tensor = aten::sub(%rv.14, %23, %24) # test.py:12:18-(%rv.3)block1():%rv.6 : Tensor = aten::add(%rv.14, %23, %24) # test.py:14:18-(%rv.6)-(%17, %rv.13)return (%rv)
以指令%rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10
为例:
%rv.1 : Tensor
表示我们将输出赋值给一个名为rv.1
的(唯一)值,该值的类型为Tensor
,且我们不知道其具体形状。aten::zeros
是运算符(相当于torch.zeros
),输入列表(%4, %6, %6, %10, %12)
指定了应传入哪些作用域中的值作为输入。像aten::zeros
这样的内置函数的模式可以在内置函数中找到。# test.py:9:10
是生成此指令的原始源文件中的位置。在本例中,它位于名为test.py的文件中,第9行,第10个字符。
注意,运算符也可以关联blocks
,即prim::Loop
和prim::If
运算符。在图形打印输出中,这些运算符的格式会反映其等效的源代码形式,以便于调试。
可以按照所示方式检查图形,以确认由ScriptModule
描述的计算是否正确,无论是自动还是手动方式,如下所述。
追踪器
追踪边界情况
在某些特殊情况下,对给定Python函数/模块的追踪可能无法准确反映底层代码的真实行为。这些情况包括:
- 依赖于输入的控制流追踪(例如张量形状)
- 张量视图就地操作的追踪(例如赋值语句左侧的索引操作)
请注意,这些情况未来实际上可能会变得可追踪。
自动追踪检查
自动捕获追踪中多种错误的一种方法是使用 torch.jit.trace()
API 上的 check_inputs
参数。check_inputs
接收一个由输入元组组成的列表,这些输入将用于重新追踪计算并验证结果。例如:
def loop_in_traced_fn(x):result = x[0]for i in range(x.size(0)):result = result * x[i]return resultinputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs)
提供以下诊断信息:
ERROR: Graphs differed across invocations!
Graph diff:graph(%x : Tensor) {%1 : int = prim::Constant[value=0]()%2 : int = prim::Constant[value=0]()%result.1 : Tensor = aten::select(%x, %1, %2)%4 : int = prim::Constant[value=0]()%5 : int = prim::Constant[value=0]()%6 : Tensor = aten::select(%x, %4, %5)%result.2 : Tensor = aten::mul(%result.1, %6)%8 : int = prim::Constant[value=0]()%9 : int = prim::Constant[value=1]()%10 : Tensor = aten::select(%x, %8, %9)- %result : Tensor = aten::mul(%result.2, %10)+ %result.3 : Tensor = aten::mul(%result.2, %10)? ++%12 : int = prim::Constant[value=0]()%13 : int = prim::Constant[value=2]()%14 : Tensor = aten::select(%x, %12, %13)+ %result : Tensor = aten::mul(%result.3, %14)+ %16 : int = prim::Constant[value=0]()+ %17 : int = prim::Constant[value=3]()+ %18 : Tensor = aten::select(%x, %16, %17)- %15 : Tensor = aten::mul(%result, %14)? ^ ^+ %19 : Tensor = aten::mul(%result, %18)? ^ ^- return (%15);? ^+ return (%19);? ^}
这条消息表明,计算过程在我们首次追踪时和使用 check_inputs
进行追踪时出现了差异。实际上,loop_in_traced_fn
函数体中的循环依赖于输入 x
的形状,因此当我们尝试使用不同形状的另一个 x
时,追踪结果就会发生变化。
对于这种情况,可以使用 torch.jit.script()
来捕获此类数据依赖的控制流:
def fn(x):result = x[0]for i in range(x.size(0)):result = result * x[i]return resultinputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]scripted_fn = torch.jit.script(fn)
print(scripted_fn.graph)
#print(str(scripted_fn.graph).strip())
for input_tuple in [inputs] + check_inputs:torch.testing.assert_close(fn(input_tuple), scripted_fn(input_tuple))
输出结果为:
graph(%x : Tensor) {%5 : bool = prim::Constant[value=1]()%1 : int = prim::Constant[value=0]()%result.1 : Tensor = aten::select(%x, %1, %1)%4 : int = aten::size(%x, %1)%result : Tensor = prim::Loop(%4, %5, %result.1)block0(%i : int, %7 : Tensor) {%10 : Tensor = aten::select(%x, %1, %i)%result.2 : Tensor = aten::mul(%7, %10)-(%5, %result.2)}return (%result);
}
追踪器警告
追踪器会对追踪计算中的几种问题模式产生警告。例如,假设对一个包含张量切片(视图)进行原地赋值的函数进行追踪:
def fill_row_zero(x):x[0] = torch.rand(x.shape[1:2])return xtraced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)
生成多个警告信息和一个直接返回输入数据的图表
fill_row_zero.py:4: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.x[0] = torch.rand(x.shape[1:2])
fill_row_zero.py:6: TracerWarning: Output nr 1、of the traced function does not match the corresponding output of the Python function. Detailed error:
Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 1] (0.09115803241729736 vs. 0.6782537698745728) and 3 other locations (33.00%)traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
graph(%0 : Float(3, 4)) {return (%0);
}
我们可以通过修改代码来解决这个问题,不再使用原地更新,而是使用torch.cat
来非原地构建结果张量:
def fill_row_zero(x):x = torch.cat((torch.rand(1, x.shape[1:2]), x[1:2]), dim=0)return xtraced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)
常见问题解答
Q: 我想在GPU上训练模型,然后在CPU上进行推理。有哪些最佳实践?
首先将模型从GPU转换到CPU,然后保存它,如下所示:
cpu_model = gpu_model.cpu()
sample_input_cpu = sample_input_gpu.cpu()
traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu)
torch.jit.save(traced_cpu, "cpu.pt")traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu)
torch.jit.save(traced_gpu, "gpu.pt")# ... later, when using the model:if use_gpu:model = torch.jit.load("gpu.pt")
else:model = torch.jit.load("cpu.pt")model(input)
推荐采用此方式,因为追踪器可能会观测到张量在特定设备上创建的过程,直接转换已加载的模型可能产生意外效果。在保存模型之前进行类型转换,可确保追踪器获取正确的设备信息。
问:如何在ScriptModule
上存储属性?
假设我们有如下模型:
import torchclass Model(torch.nn.Module):def __init__(self):super().__init__()self.x = 2def forward(self):return self.xm = torch.jit.script(Model())
如果直接实例化 Model
会导致编译错误,因为编译器无法识别 x
。有四种方法可以让编译器识别 ScriptModule
上的属性:
1、nn.Parameter
- 用 nn.Parameter
包装的值会像在 nn.Module
中一样正常工作。
2、register_buffer
- 用 register_buffer
包装的值会像在 nn.Module
中一样正常工作。这相当于一个类型为 Tensor
的属性(见第4点)。
3、常量 - 将类成员标注为 Final
(或在类定义级别将其添加到名为 __constants__
的列表中)会将包含的名称标记为常量。常量会直接保存在模型的代码中。详情请参阅内置常量。
4、属性 - 支持类型的值可以作为可变属性添加。大多数类型可以自动推断,但有些可能需要明确指定,详情请参阅模块属性。
问题:我想追踪模块的方法,但一直遇到这个错误:
RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient
这个错误通常意味着你正在追踪的方法使用了模块的参数,而你传递的是模块的方法而不是模块实例(例如 my_module_instance.forward
与 my_module_instance
)。
- 使用模块的方法调用
trace
会将模块参数(可能需要梯度)捕获为常量。 - 另一方面,使用模块实例(例如
my_module
)调用trace
会创建一个新模块,并正确地将参数复制到新模块中,因此如果需要,它们可以累积梯度。
要追踪模块上的特定方法,请参阅 torch.jit.trace_module
。
已知问题
当你在 TorchScript 中使用 Sequential
时,某些 Sequential
子模块的输入可能会被错误推断为 Tensor
,即使它们被标注为其他类型。标准解决方案是继承 nn.Sequential
并重新声明 forward
方法,确保输入类型正确。
附录
迁移至 PyTorch 1.2 递归脚本化 API
本节详细说明 PyTorch 1.2 中 TorchScript 的变化。如果你是 TorchScript 的新用户,可以跳过这部分内容。PyTorch 1.2 对 TorchScript API 主要做了两处改动:
1、torch.jit.script
现在会尝试递归编译遇到的函数、方法和类。一旦调用 torch.jit.script
,编译过程将采用"默认启用"而非"手动启用"机制。
2、torch.jit.script(nn_module_instance)
现已成为创建 ScriptModule
的推荐方式,取代了原先继承 torch.jit.ScriptModule
的做法。
这些改动共同提供了一个更简单易用的 API,用于将你的 nn.Module
转换为可优化并在非 Python 环境中执行的 ScriptModule
。
新的使用方式如下:
import torch
import torch.nn as nn
import torch.nn.functional as Fclass Model(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 20, 5)self.conv2 = nn.Conv2d(20, 20, 5)def forward(self, x):x = F.relu(self.conv1(x))return F.relu(self.conv2(x))my_model = Model()
my_scripted_model = torch.jit.script(my_model)
- 模块的
forward
方法默认会被编译。从forward
中调用的方法会按它们在forward
中的使用顺序进行惰性编译。 - 若要编译未被
forward
调用的其他方法,需添加@torch.jit.export
装饰器。 - 如需阻止编译器编译某个方法,可添加
@torch.jit.ignore
或@torch.jit.unused
。@ignore
会保留 Python 方法调用,而@unused
会将其替换为异常。@ignored
方法不可导出;@unused
方法可以导出。 - 大多数属性类型可自动推断,因此无需使用
torch.jit.Attribute
。对于空容器类型,建议使用 PEP 526 风格 的类注解来声明类型。 - 常量可通过
Final
类注解标记,无需将成员名加入__constants__
列表。 - 可用 Python 3 类型提示替代
torch.jit.annotate
函数。
基于这些变更,以下内容已被弃用,新代码中不应继续使用:
@torch.jit.script_method
装饰器- 继承自
torch.jit.ScriptModule
的类 torch.jit.Attribute
包装类__constants__
数组torch.jit.annotate
函数
模块
警告:在 PyTorch 1.2 中,@torch.jit.ignore
注解的行为发生了变化。在 PyTorch 1.2 之前,@ignore
装饰器用于使函数或方法可以从导出的代码中调用。要恢复此功能,请使用 @torch.jit.unused()
。现在 @torch.jit.ignore
等同于 @torch.jit.ignore(drop=False)
。详情请参阅 @torch.jit.ignore
和 @torch.jit.unused
。
当传递给 torch.jit.script
函数时,torch.nn.Module
的数据会被复制到 ScriptModule
中,并由 TorchScript 编译器编译该模块。默认情况下,模块的 forward
方法会被编译。从 forward
调用的方法会按照它们在 forward
中的使用顺序延迟编译,同时也会编译任何带有 @torch.jit.export
注解的方法。
torch.jit.export(fn)
这个装饰器用于标记nn.Module
中的某个方法作为ScriptModule
的入口点,该方法将被编译。
forward
方法默认被视为入口点,因此不需要此装饰器。
从forward
调用的函数和方法会在编译器处理时自动编译,所以它们也不需要这个装饰器。
示例(在方法上使用@torch.jit.export
装饰器):
import torch
import torch.nn as nnclass MyModule(nn.Module):def implicitly_compiled_method(self, x):return x + 99# `forward` is implicitly decorated with `@torch.jit.export`, # so adding it here would have no effectdef forward(self, x):return x + 10@torch.jit.exportdef another_forward(self, x):# When the compiler sees this call, it will compile# `implicitly_compiled_method`return self.implicitly_compiled_method(x)def unused_method(self, x):return x - 20# `m` will contain compiled methods:
# `forward`
# `another_forward`
# `implicitly_compiled_method`
# `unused_method` will not be compiled since it was not called from # any compiled methods and wasn't decorated with `@torch.jit.export`
m = torch.jit.script(MyModule())
函数
函数基本保持不变,必要时可以使用 @torch.jit.ignore
或 torch.jit.unused
装饰器进行修饰。
# Same behavior as pre-PyTorch 1.2
@torch.jit.script
def some_fn():return 2# Marks a function as ignored, if nothing
# ever calls it then this has no effect
@torch.jit.ignore
def some_fn2():return 2# As with ignore, if nothing calls it then it has no effect.
# If it is called in script it is replaced with an exception.
@torch.jit.unused
def some_fn3():import pdb; pdb.set_trace()return 4# Doesn't do anything, this function is already
# the main entry point
@torch.jit.export
def some_fn4():return 2
TorchScript 类
警告:TorchScript 类的支持目前处于实验阶段。当前最适合用于简单的记录式类型(可理解为附加了方法的NamedTuple
)。
用户定义的 TorchScript 类中所有内容默认会被导出,如有需要可以使用 @torch.jit.ignore
装饰器来忽略特定函数。
属性
TorchScript 编译器需要知道模块属性的类型。大多数类型可以通过成员的值推断出来。空列表和字典无法推断其类型,必须使用 PEP 526 风格 的类注解显式标注类型。如果某个类型既无法推断又未显式标注,则不会将其作为属性添加到最终的 ScriptModule
中。
旧版 API:
from typing import Dict
import torchclass MyModule(torch.jit.ScriptModule):def __init__(self):super().__init__()self.my_dict = torch.jit.Attribute({}, Dict[str, int])self.my_int = torch.jit.Attribute(20, int)m = MyModule()
新API:
from typing import Dictclass MyModule(torch.nn.Module):my_dict: Dict[str, int]def __init__(self):super().__init__()# This type cannot be inferred and must be specifiedself.my_dict = {}# The attribute type here is inferred to be `int`self.my_int = 20def forward(self):passm = torch.jit.script(MyModule())
常量
Final
类型构造器可用于将成员标记为常量。如果成员未被标记为常量,它们将被复制到生成的 ScriptModule
中作为属性。使用 Final
可以在已知值固定的情况下开启优化机会,并提供额外的类型安全性。
旧版 API:
class MyModule(torch.jit.ScriptModule):__constants__ = ['my_constant']def __init__(self):super().__init__()self.my_constant = 2def forward(self):pass
m = MyModule()
新 API:
from typing import Finalclass MyModule(torch.nn.Module):my_constant: Final[int]def __init__(self):super().__init__()self.my_constant = 2def forward(self):passm = torch.jit.script(MyModule())
变量
容器默认具有 Tensor
类型且不可为空(更多信息请参阅默认类型章节)。之前使用 torch.jit.annotate
来告知 TorchScript 编译器类型信息,现在已支持 Python 3 风格的类型提示。
import torch
from typing import Dict, Optional@torch.jit.script
def make_dict(flag: bool):x: Dict[str, int] = {}x['hi'] = 2b: Optional[int] = Noneif flag:b = 2return x, b
融合后端
TorchScript 执行优化提供了几种融合后端选择。CPU 上的默认融合器是 NNC,它支持 CPU 和 GPU 的融合操作。而 GPU 上的默认融合器是 NVFuser,它支持更广泛的运算符,并已证明能生成具有更高吞吐量的内核。有关使用和调试的更多细节,请参阅 NVFuser 文档。
参考资料
- Python 语言参考覆盖范围
- TorchScript 不支持的 PyTorch 结构
torch.linalg
常用线性代数运算。
有关常见数值边界情况的说明,请参阅线性代数 (torch.linalg)。
矩阵属性
norm | 计算向量或矩阵范数 |
---|---|
vector_norm | 计算向量范数 |
matrix_norm | 计算矩阵范数 |
diagonal | torch.diagonal() 的别名,默认参数为dim1 = -2, dim2 = -1 |
det | 计算方阵的行列式 |
slogdet | 计算方阵行列式绝对值的符号和自然对数 |
cond | 计算矩阵关于某个矩阵范数的条件数 |
matrix_rank | 计算矩阵的数值秩 |
矩阵分解
cholesky | 计算复数厄米特矩阵或实数对称正定矩阵的Cholesky分解 |
---|---|
qr | 计算矩阵的QR分解 |
lu | 计算带部分主元消去的矩阵LU分解 |
lu_factor | 计算带部分主元消去的矩阵LU分解的紧凑表示形式 |
eig | 计算方阵的特征值分解(如果存在) |
eigvals | 计算方阵的特征值 |
eigh | 计算复数厄米特矩阵或实数对称矩阵的特征值分解 |
eigvalsh | 计算复数厄米特矩阵或实数对称矩阵的特征值 |
svd | 计算矩阵的奇异值分解(SVD) |
svdvals | 计算矩阵的奇异值 |
求解器
solve | 计算具有唯一解的线性方程组的解。 |
---|---|
solve_triangular | 计算具有唯一解的三角线性方程组的解。 |
lu_solve | 在给定LU分解的情况下,计算具有唯一解的线性方程组的解。 |
lstsq | 计算线性方程组的最小二乘解。 |
逆矩阵
inv | 计算方阵的逆矩阵(如果存在)。 |
---|---|
pinv | 计算矩阵的伪逆(Moore-Penrose 逆)。 |
矩阵函数
matrix_exp | 计算方阵的矩阵指数 |
---|---|
matrix_power | 计算方阵的整数n次幂 |
矩阵运算
cross | 计算两个三维向量的叉积 |
---|---|
matmul | torch.matmul() 的别名 |
vecdot | 沿指定维度计算两批向量的点积 |
multi_dot | 通过优化乘法顺序来高效计算两个及以上矩阵的连乘,实现最少算术运算 |
householder_product | 计算Householder矩阵乘积的前n列 |
张量运算
tensorinv | 计算 torch.tensordot() 的乘法逆元。 |
---|---|
tensorsolve | 计算方程组 torch.tensordot(A, X) = B 的解 X。 |
杂项函数
vander | 生成范德蒙矩阵 |
---|
实验性函数
cholesky_ex | 计算复厄米特矩阵或实对称正定矩阵的Cholesky分解。 |
---|---|
inv_ex | 计算可逆方阵的逆矩阵。 |
solve_ex | solve() 的一个变体,除非 check_errors = True,否则不执行错误检查。 |
lu_factor_ex | 这是 lu_factor() 的一个变体,除非 check_errors = True,否则不执行错误检查。 |
ldl_factor | 计算厄米特矩阵或对称矩阵(可能不定)的LDL分解的紧凑表示。 |
ldl_factor_ex | 这是 ldl_factor() 的一个变体,除非 check_errors = True,否则不执行错误检查。 |
ldl_solve | 使用LDL分解计算线性方程组的解。 |
torch.monitor
警告:本模块为原型版本,其接口和功能可能在未来的PyTorch版本中未经通知即发生变更。
torch.monitor
提供了从PyTorch记录事件和计数器的接口。
统计接口设计用于追踪高层次指标,这些指标会定期记录以监控系统性能。由于统计数据会按特定窗口大小进行聚合,您可以在关键循环中记录它们而对性能影响极小。
对于不频繁发生的事件或数值(如损失值、准确率、使用情况追踪),可以直接使用事件接口。
可以注册事件处理器来处理事件,并将其传递至外部事件接收器。
API 参考
class torch.monitor.Aggregation
以下是可用的统计聚合类型:
成员说明:
VALUE :VALUE 返回最后添加的值。
MEAN :MEAN 计算所有添加值的算术平均值。
COUNT :COUNT 返回已添加值的总数量。
SUM :SUM 返回所有添加值的总和。
MAX :MAX 返回添加值中的最大值。
MIN :MIN 返回添加值中的最小值。
property name
class torch.monitor.Stat
Stat 用于在固定时间间隔内高效计算汇总统计量。Stat 会每隔 window_size
时长将统计结果记录为一个 Event 事件。当时间窗口关闭时,统计结果会通过事件处理器以 torch.monitor.Stat
事件的形式记录。
建议将 window_size
设置为较高的值(例如 60 秒),以避免记录过多事件。Stat 使用毫秒级精度。
如果设置了 max_samples
参数,Stat 会通过丢弃超出限制的 add
调用来限制每个窗口的最大样本数。未设置该参数时,窗口期内所有的 add
调用都会被纳入统计。这个可选字段主要用于在样本量可能波动的情况下,使不同窗口期的聚合数据更具可比性。
当 Stat 对象被销毁时,即使当前时间窗口尚未结束,它也会记录所有剩余数据。
__init__(self: torch._C._monitor.Stat, name: str , aggregations: list [torch._C._monitor.Aggregation], window_size: [datetime.timedelta, max_samples: int = 9223372036854775807) → None
构造 Stat
对象。
add(self: torch._C._monitor.Stat, v: float) → None
Adds a value to the stat to be aggregated according to the configured stat type and aggregations.
property count
Number of data points that have currently been collected. Resets
once the event has been logged.
get(self: torch._C._monitor.Stat) → dict[torch._C._monitor.Aggregation, float]
Returns the current value of the stat, primarily for testing purposes. If the stat has logged and no additional values have been added this will be zero.
property name
The name of the stat that was set during creation.
class torch.monitor.data_value_t
data_value_t is one of str
, float
, int
, bool
.
class torch.monitor.Event
Event represents a specific typed event to be logged. This can represent high-level data points such as loss or accuracy per epoch or more low-level aggregations such as through the Stats provided through this library.
All Events of the same type should have the same name so downstream
handlers can correctly process them.
__init__(self: torch._C._monitor.Event, name: str, timestamp: datetime.datetime, data: dict[str, data_value_t]) → None
Constructs the Event
.
property data
The structured data contained within the Event
.
property name
The name of the Event
.
property timestamp
The timestamp when the Event
happened.
class torch.monitor.EventHandlerHandle
EventHandlerHandle is a wrapper type returned by register_event_handler
used to unregister the handler via unregister_event_handler
. This cannot be directly initialized.
torch.monitor.log_event(event: torch._C._monitor.Event) → None
log_event logs the specified event to all of the registered event handlers. It’s up to the event handlers to log the event out to the corresponding event sink.
If there are no event handlers registered this method is a no-op.
torch.monitor.register_event_handler(callback: Callable[[torch._C._monitor.Event], None ]) → torch._C._monitor.EventHandlerHandl)
register_event_handler registers a callback to be called whenever an event is logged via log_event
. These handlers should avoid blocking the main thread since that may interfere with training as they run during the log_event
call.
torch.monitor.unregister_event_handler(handler: torch._C._monitor.EventHandlerHandl)) → None
unregister_event_handler unregisters the EventHandlerHandle
returned after calling register_event_handler
. After this returns the event handler will no longer receive events.
class torch.monitor.TensorboardEventHandler(writer)
TensorboardEventHandler is an event handler that will write known events to the provided SummaryWriter.
This currently only supports torch.monitor.Stat
events which are logged as scalars.
Example :
>>> from torch.utils.tensorboard import SummaryWriter>>> from torch.monitor import TensorboardEventHandler, register_event_handler>>> writer = SummaryWriter("log_dir")>>> register_event_handler(TensorboardEventHandler(writer))
__init__(writer)
构建 TensorboardEventHandler
。
torch.signal 模块
torch.signal 模块的设计灵感来源于 SciPy 的 signal 模块。
torch.signal.windows 窗口函数
bartlett | 计算巴特利特(Bartlett)窗口 |
---|---|
blackman | 计算布莱克曼(Blackman)窗口 |
cosine | 计算余弦窗口,实现方式与SciPy保持一致 |
exponential | 计算指数窗口 |
gaussian | 计算高斯窗口 |
general_cosine | 计算广义余弦窗口 |
general_hamming | 计算广义汉明(Hamming)窗口 |
hamming | 计算汉明(Hamming)窗口 |
hann | 计算汉宁(Hann)窗口 |
kaiser | 计算凯撒(Kaiser)窗口 |
nuttall | 根据Nuttall方法计算最小4项布莱克曼-哈里斯(Blackman-Harris)窗口 |
torch.special
torch.special 模块的设计灵感来源于 SciPy 的 special 模块。
函数
torch.special.airy_ai(input, *, out=None) → Tensor
Airy 函数 Ai(input)Ai(input)Ai(input)。
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , 可选)
– 输出张量。
torch.special.bessel_j0(input, *, out=None) → Tensor
第一类零阶贝塞尔函数。
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , 可选 )
– 输出张量。
torch.special.bessel_j1(input, *, out=None) → Tensor
第一类111阶贝塞尔函数。
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , optional)
– 输出张量。
torch.special.digamma(input, *, out=None) → Tensor
计算输入张量的伽玛函数的对数导数。
ϝ(x)=ddxln(Γ(x))=Γ′(x)Γ(x)\digamma(x) = \frac{d}{dx} \ln\left(\Gamma\left(x\right)\right) = \frac{\Gamma'(x)}{\Gamma(x)}ϝ(x)=dxdln(Γ(x))=Γ(x)Γ′(x)
参数
input ( Tensor )
– 用于计算digamma函数的输入张量
关键字参数
out ( Tensor , optional)
– 输出张量
注意:此函数与SciPy的scipy.special.digamma功能相似。
注意:从PyTorch 1.8开始,digamma函数在输入为0时会返回-Inf,而此前版本会返回NaN。
示例:
>>> a = torch.tensor([1, 0.5])
>>> torch.special.digamma(a)
tensor([-0.5772, -1.9635])
torch.special.entr(input, *, out=None) → Tensor
计算输入张量 input
中各元素的熵(定义如下)。
$$
\begin{align}
\text{entr(x)} = \begin{cases}
-x * \ln(x) & x 0 \
0 & x = 0.0 \
-\infty & x < 0
\end{cases}
\end{align}
$$
参数说明
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , optional)
– 输出张量。
示例:
>>> a = torch.arange(-0.5, 1, 0.5)
>>> a tensor([-0.5000, 0.0000, 0.5000])
>>> torch.special.entr(a)
tensor([-inf, 0.0000, 0.3466])
torch.special.erf(input, *, out=None) → Tensor
计算输入张量的误差函数。误差函数定义如下:
$$\mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_{0}^{x} e{-t2} dt
$$
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , optional)
– 输出张量。
示例:
>>> torch.special.erf(torch.tensor([0, -1., 10.]))
tensor([0.0000, -0.8427, 1.0000])
torch.special.erfc(input, *, out=None) → Tensor
计算输入张量的互补误差函数。
互补误差函数的定义如下:
$$
\mathrm{erfc}(x) = 1 - \frac{2}{\sqrt{\pi}} \int_{0}^{x} e{-t2} dt
$$
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , optional)
– 输出张量。
示例:
>>> torch.special.erfc(torch.tensor([0, -1., 10.]))
tensor([1.0000, 1.8427, 0.0000])
torch.special.erfcx(input, *, out=None) → Tensor
计算input
中每个元素的缩放互补误差函数。
缩放互补误差函数的定义如下:
erfcx(x)=ex2erfc(x)\mathrm{erfcx}(x) = e^{x^2} erfc(x) erfcx(x)=ex2erfc(x)
erfcx(x)=ex2erfc(x)
参数
input ( Tensor )
- 输入张量。
关键字参数
out ( Tensor , 可选)
- 输出张量。
示例
>>> torch.special.erfcx(torch.tensor([0, -1., 10.]))
tensor([1.0000, 5.0090, 0.0561])
torch.special.erfinv(input, *, out=None) → Tensor
计算输入张量的反误差函数值。
反误差函数在区间 (−1,1) 内定义为:
erfinv(erf(x))=x\mathrm{erfinv(erf(x))}= x erfinv(erf(x))=x
参数说明
input ( Tensor )
- 输入张量
关键字参数
out ( Tensor , 可选)
- 输出张量
使用示例
>>> torch.special.erfinv(torch.tensor([0, 0.5, -1.]))
tensor([0.0000, 0.4769, -inf])
torch.special.exp2(input, *, out=None) → Tensor
计算 input
的以 2 为底的指数函数。
yi=2xiy_{i} = 2^{x_{i}}yi=2xi
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , 可选)
– 输出张量。
示例:
>>> torch.special.exp2(torch.tensor([0, math.log2(2.), 3, 4]))
tensor([1., 2., 8., 16.])
torch.special.expit(input, *, out=None) → Tensor
计算输入张量 input
各元素的 expit 值(也称为 logistic sigmoid 函数)。
outi=11+ei−input{out}_{i} = \frac{1}{1 + e^{-input}_{i}} outi=1+ei−input1
outi=1+e−inputi1
参数
input ( Tensor )
- 输入张量。
关键字参数
out ( Tensor , 可选)
- 输出张量。
示例:
>>> t = torch.randn(4)
>>> t
tensor([0.9213, 1.0887, -0.8858, -1.7683])
>>> torch.special.expit(t)
tensor([0.7153, 0.7481, 0.2920, 0.1458])
torch.special.expm1(input, *, out=None) → Tensor
计算输入张量input
各元素的指数值减1。
yi=exi−1y_{i} = e^{x_{i}} - 1yi=exi−1
注意:对于较小的x值,该函数比直接计算exp(x) - 1能提供更高的精度。
参数说明
input ( Tensor )
- 输入张量。
关键字参数
out ( Tensor , 可选)
- 输出张量。
使用示例
>>> torch.special.expm1(torch.tensor([0, math.log(2.)]))
tensor([0., 1.])
torch.special.gammainc(input, other, *, out=None) → Tensor
计算正则化的下不完全伽马函数:
$$
\text{out}_{i} = \frac{1}{\Gamma(\text{input}_i)} \int_0^{\text{other}_i} t^{\text{input}_i-1} e^{-t} dt
$$
其中inputiinput_iinputi和otheriother_iotheri均为弱正数且至少有一个严格为正数。若两者均为零或任一为负数,则outi=nan\text{out}_i=\text{nan}outi=nan。上述公式中的Γ\GammaΓ表示伽马函数:
$$
\Gamma(\text{input}_i) = \int_0^\infty t^{(\text{input}_i-1)} e^{-t} dt.
$$
相关函数请参阅torch.special.gammaincc()
和torch.special.gammaln()
。
支持广播至通用形状和浮点输入。
注意:目前不支持对input
的反向传播。如需此功能,请在PyTorch的Github上提交issue。
参数
input ( Tensor )
– 第一个非负输入张量other ( Tensor )
– 第二个非负输入张量
关键字参数
out ( Tensor , optional)
– 输出张量
示例:
>>> a1 = torch.tensor([4.0])
>>> a2 = torch.tensor([3.0, 4.0, 5.0])
>>> a = torch.special.gammaincc(a1, a2)
tensor([0.3528, 0.5665, 0.7350])
tensor([0.3528, 0.5665, 0.7350])
>>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2)
tensor([1., 1., 1.])
torch.special.gammaincc(input, other, *, out=None) → Tensor
计算正则化上不完全伽马函数:
$$\text{out}_{i} = \frac{1}{\Gamma(\text{input}i)} \int{\text{other}_i}^{\infty} t^{\text{input}_i-1} e^{-t} dt
$$
其中 inputiinput_iinputi 和 otheriother_iotheri 均为弱正数,且至少有一个严格为正数。若两者均为零或任一为负数,则 outi=nanout_i=nanouti=nan。上式中的 Γ(⋅)\Gamma(\cdot)Γ(⋅) 表示伽马函数,
$$\Gamma(\text{input}_i) = \int_0^\infty t^{(\text{input}_i-1)} e^{-t} dt.
$$
相关函数请参阅 torch.special.gammainc()
和 torch.special.gammaln()
。
支持广播至相同形状及浮点输入。
注意:目前不支持对 input
的反向传播。如需此功能,请在 PyTorch 的 Github 上提交 issue。
参数
input ( Tensor )
– 第一个非负输入张量other ( Tensor )
– 第二个非负输入张量
关键字参数
out ( Tensor , optional)
– 输出张量。
示例:
>>> a1 = torch.tensor([4.0])
>>> a2 = torch.tensor([3.0, 4.0, 5.0])
>>> a = torch.special.gammaincc(a1, a2)
tensor([0.6472, 0.4335, 0.2650])
>>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2)
tensor([1., 1., 1.])
torch.special.gammaln(input, *, out=None) → Tensor
计算输入张量绝对值的伽玛函数的自然对数。
outi=lnΓ(∣inputi∣)\text{out}{i} = \ln \Gamma(|\text{input}{i}|)
outi=lnΓ(∣inputi∣)
参数
input ( Tensor )
- 输入张量。
关键字参数
out ( Tensor , optional)
- 输出张量。
示例
>>> a = torch.arange(0.5, 2, 0.5)
>>> torch.special.gammaln(a)
tensor([0.5724, 0.0000, -0.1208])
torch.special.i0(input, *, out=None) → Tensor
计算input
中每个元素的第一类零阶修正贝塞尔函数。
$$\text{out}{i} = I_0(\text{input}{i}) = \sum_{k=0}^{\infty} \frac{(\text{input}_{i}2/4)k}{(k!)^2}
$$
参数
input ( Tensor )
- 输入张量
关键字参数
out ( Tensor , 可选)
- 输出张量
示例
>>> torch.i0(torch.arange(5, dtype=torch.float32))
tensor([1.0000, 1.2661, 2.2796, 4.8808, 11.3019])
torch.special.i0e(input, *, out=None) → Tensor
为 input
的每个元素计算指数缩放的第一类零阶修正贝塞尔函数(定义如下)。
$$\text{out}{i} = \exp(-|x|) * i0(x) = \exp(-|x|) * \sum{k=0}^{\infty} \frac{(\text{input}_{i}2/4)k}{(k!)^2}
$$
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , optional)
– 输出张量。
示例:
>>> torch.special.i0e(torch.arange(5, dtype=torch.float32))
tensor([1.0000, 0.4658, 0.3085, 0.2430, 0.2070])
torch.special.i1(input, *, out=None) → Tensor
计算input
中每个元素的一阶第一类修正贝塞尔函数(定义如下)。
$$\text{out}_{i} = \exp(-|x|) * i1(x) =
\exp(-|x|) * \frac{(\text{input}{i})}{2} * \sum{k=0}^{\infty} \frac{(\text{input}_{i}2/4)k}{(k!) * (k+1)!}
$$
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , optional)
– 输出张量。
示例
>>> torch.special.i1(torch.arange(5, dtype=torch.float32))
tensor([0.0000, 0.5652, 1.5906, 3.9534, 9.7595])
torch.special.i1e(input, *, out=None) → Tensor
计算input
中每个元素的指数缩放一阶第一类修正贝塞尔函数(定义如下):
outi=exp(−∣x∣)∗i1(x)=exp(−∣x∣)∗(inputi)2∗∑k=0∞(inputi2/4)k(k!)∗(k+1)!\text{out}_{i} = \exp(-|x|) * i1(x) =
\exp(-|x|) * \frac{(\text{input}{i})}{2} * \sum{k=0}^{\infty} \frac{(\text{input}_{i}2/4)k}{(k!) * (k+1)!}
outi=exp(−∣x∣)∗i1(x)=exp(−∣x∣)∗2(inputi)∗k=0∑∞(k!)∗(k+1)!(inputi2/4)k
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , optional)
– 输出张量。
示例:
>>> torch.special.i1e(torch.arange(5, dtype=torch.float32))
tensor([0.0000, 0.2079, 0.2153, 0.1968, 0.1788])
torch.special.log1p(input, *, out=None) → Tensor
torch.log1p()
的别名。
torch.special.log_ndtr(input, *, out=None) → Tensor
计算标准高斯概率密度函数从负无穷到input
的逐元素积分对数。
$$\text{log_softmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)
$$
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , 可选)
– 输出张量。
示例
>>> torch.special.log_ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3]))
tensor([-6.6077 -3.7832 -1.841 -0.6931 -0.1728 -0.023 -0.0014])
torch.special.log_softmax(input, dim, *, dtype=None) → Tensor
计算经过对数处理的softmax结果。
虽然在数学上等价于log(softmax(x)),但分开执行这两个操作会更慢且数值不稳定。该函数的计算方式如下:
log_softmax(xi)=log(exp(xi)∑jexp(xj))\text{log_softmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)
log_softmax(xi)=log(∑jexp(xj)exp(xi))
参数说明
input ( Tensor )
– 输入张量dim ( int )
– 指定计算log_softmax的维度dtype (
torch.dtype, optional)
– 返回张量的期望数据类型
如果指定该参数,在执行操作前会将输入张量转换为dtype
类型。这有助于防止数据类型溢出。默认值:None。
使用示例:
>>> t = torch.ones(2, 2)
>>> torch.special.log_softmax(t, 0)
tensor([[-0.6931, -0.6931], [-0.6931, -0.6931]])
torch.special.logit(input, eps=None, *, out=None) → Tensor
返回一个包含input
元素logit值的新张量。
当eps不为None时,input
会被截断到[eps, 1 - eps]区间。
当eps为None且input
< 0或input
> 1时,函数将返回NaN。
yi=ln(zi1−zi)zi={xiif eps is Noneepsif xi<epsxiif eps≤xi≤1−eps1−epsif xi>1−eps\begin{align} y_{i} &= \ln(\frac{z_{i}}{1 - z_{i}}) \\ z_{i} &= \begin{cases} x_{i} & \text{if eps is None} \\ \text{eps} & \text{if } x_{i} < \text{eps} \\ x_{i} & \text{if } \text{eps} \leq x_{i} \leq 1 - \text{eps} \\ 1 - \text{eps} & \text{if } x_{i} > 1 - \text{eps} \end{cases} \end{align}yizi=ln(1−zizi)=⎩⎨⎧xiepsxi1−epsif eps is Noneif xi<epsif eps≤xi≤1−epsif xi>1−eps
参数
input ( Tensor )
- 输入张量。eps (float, 可选)
- 用于输入截断的epsilon值。默认值:None
关键字参数
out ( Tensor , 可选)
- 输出张量。
示例:
>>> a = torch.rand(5)
>>> a tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516])
>>> torch.special.logit(a, eps=1e-6)
tensor([-0.9466, 2.6352, 0.6131, -1.7169, 0.6261])
torch.special.logsumexp(input, dim, keepdim=False, *, out=None)
torch.logsumexp()
的别名。
torch.special.multigammaln(input, p, *, out=None) → Tensor
计算给定维度 p 的多元对数伽玛函数,按元素逐个计算,公式如下:
$$\log(\Gamma_{p}(a)) = C + \displaystyle \sum_{i=1}^{p} \log\left(\Gamma\left(a - \frac{i - 1}{2}\right)\right)
$$
其中 C=log(π)⋅p(p−1)4C = \log(\pi) \cdot \frac{p (p - 1)}{4}C=log(π)⋅4p(p−1) 为伽玛函数。
所有元素必须大于 p−12\frac{p - 1}{2}2p−1,否则行为未定义。
参数
input ( Tensor )
- 用于计算多元对数伽玛函数的张量p ( int )
- 维度数量
关键字参数
out ( Tensor , optional)
- 输出张量
示例
>>> a = torch.empty(2, 3).uniform_(1, 2)
>>> a tensor([[1.6835, 1.8474, 1.1929], [1.0475, 1.7162, 1.4180]])
>>> torch.special.multigammaln(a, 2)
tensor([[0.3928, 0.4007, 0.7586], [1.0311, 0.3901, 0.5049]])
torch.special.ndtr(input, *, out=None) → Tensor
计算标准高斯概率密度函数从负无穷到输入值input
的逐元素积分面积。
ndtr(x)=12π∫−∞xe−12t2dt\text{ndtr}(x) = \frac{1}{\sqrt{2 \pi}}\int_{-\infty}^{x} e^{-\frac{1}{2}t^2} dtndtr(x)=2π1∫−∞xe−21t2dt
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , optional)
– 输出张量。
示例
>>> torch.special.ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3]))
tensor([0.0013, 0.0228, 0.1587, 0.5000, 0.8413, 0.9772, 0.9987])
torch.special.ndtri(input, *, out=None) → Tensor
计算高斯概率密度函数下(从负无穷积分到x)面积等于input
各元素值的对应参数x。
ndtri(p)=2erf−1(2p−1)\text{ndtri}(p) = \sqrt{2}\text{erf}^{-1}(2p - 1)ndtri(p)=2erf−1(2p−1)
注意:也称为正态分布的分位数函数。
参数说明:
input ( Tensor )
- 输入张量
关键字参数:
out ( Tensor , optional)
- 输出张量
示例:
>>> torch.special.ndtri(torch.tensor([0, 0.25, 0.5, 0.75, 1]))
tensor([ -inf, -0.6745, 0.0000, 0.6745, inf])
torch.special.polygamma(n, input, *, out=None) → Tensor
计算输入张量 input
的 digamma 函数的 nthn^{th}nth 阶导数。
其中 n≥0n \geq 0n≥0 称为多伽马函数的阶数。
$$\psi^{(n)}(x) = \frac{d{(n)}}{dx{(n)}} \psi(x)
$$
注意:此函数仅针对非负整数 n≥0 实现。
参数说明
n ( int )
– 多伽马函数的阶数input ( Tensor )
– 输入张量
关键字参数
out ( Tensor , optional)
– 输出张量
使用示例:
>>> a = torch.tensor([1, 0.5])
>>> torch.special.polygamma(1, a)
tensor([1.64493, 4.9348])
>>> torch.special.polygamma(2, a)
tensor([-2.4041, -16.8288])
>>> torch.special.polygamma(3, a)
tensor([6.4939, 97.4091])
>>> torch.special.polygamma(4, a)
tensor([-24.8863, -771.4742])
torch.special.psi(input, *, out=None) → Tensor
torch.special.digamma()
的别名。
torch.special.round(input, *, out=None) → Tensor
torch.round()
的别名。
torch.special.scaled_modified_bessel_k0(input, *, out=None) → Tensor
二阶修正贝塞尔函数(阶数为0)。
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , 可选)
– 输出张量。
torch.special.scaled_modified_bessel_k1(input, *, out=None) → Tensor
第二类111阶缩放修正贝塞尔函数。
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , 可选 )
– 输出张量。
torch.special.sinc(input, *, out=None) → Tensor
计算 input
的归一化 sinc 函数值。
$$\text{out}_{i} =
\begin{cases}
1, & \text{if}\ \text{input}_{i}=0 \
\sin(\pi \text{input}{i}) / (\pi \text{input}{i}), & \text{otherwise}
\end{cases}
$$
参数
input ( Tensor )
- 输入张量。
关键字参数
out ( Tensor , optional)
- 输出张量。
示例
>>> t = torch.randn(4)
>>> t
tensor([0.2252, -0.2948, 1.0267, -1.1566])
>>> torch.special.sinc(t)
tensor([0.9186, 0.8631, -0.0259, -0.1300])
torch.special.softmax(input, dim, *, dtype=None) → Tensor
计算softmax函数。
Softmax的定义如下:
Softmax(xi)=exp(xi)∑jexp(xj)\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}Softmax(xi)=∑jexp(xj)exp(xi)
该函数会沿着指定维度dim对所有切片进行计算,并将结果重新缩放,使元素值落在[0, 1]区间且总和为1。
参数
input ( Tensor )
– 输入张量dim ( int )
– 指定计算softmax的维度dtype (
torch.dtype, 可选)
– 返回张量的期望数据类型
若指定该参数,在执行操作前会将输入张量转换为dtype
类型。这有助于防止数据类型溢出。默认值:None。
示例::
>>> t = torch.ones(2, 2)
>>> torch.special.softmax(t, 0)
tensor([[0.5000, 0.5000], [0.5000, 0.5000]])
torch.special.spherical_bessel_j0(input, *, out=None) → Tensor
一阶球面贝塞尔函数(阶数为000)。
参数
input ( Tensor )
– 输入张量。
关键字参数
out ( Tensor , optional)
– 输出张量。
torch.special.xlog1py(input, other, *, out=None) → Tensor
计算 input * log1p(other)
,具体分为以下几种情况:
outi={NaNif otheri=NaN0if inputi=0.0and otheri!=NaNinputi∗log1p(otheri)otherwise\text{out}_{i} = \begin{cases} \text{NaN} & \text{if } \text{other}_{i} = \text{NaN} \\ 0 & \text{if } \text{input}_{i} = 0.0 \text{ and } \text{other}_{i} != \text{NaN} \\ \text{input}_{i} * \text{log1p}(\text{other}_{i}) & \text{otherwise} \end{cases} outi=⎩⎨⎧NaN0inputi∗log1p(otheri)if otheri=NaNif inputi=0.0 and otheri!=NaNotherwise
与 SciPy 的 scipy.special.xlog1py
功能类似。
参数
input (Number* 或 Tensor)
– 乘数other (Number* 或 Tensor)
– 参数
注意:input
和 other
中至少有一个必须是张量。
关键字参数
out (Tensor, 可选)
– 输出张量。
示例:
>>> x = torch.zeros(5,)
>>> y = torch.tensor([-1, 0, 1, float('inf'), float('nan')])
>>> torch.special.xlog1py(x, y)
tensor([0., 0., 0., 0., nan])
>>> x = torch.tensor([1, 2, 3])
>>> y = torch.tensor([3, 2, 1])
>>> torch.special.xlog1py(x, y)
tensor([1.3863, 2.1972, 2.0794])
>>> torch.special.xlog1py(x, 4)
tensor([1.6094, 3.2189, 4.8283])
>>> torch.special.xlog1py(2, y)
tensor([2.7726, 2.1972, 1.3863])
torch.special.xlogy(input, other, *, out=None) → Tensor
计算 input * log(other)
,具体有以下几种情况:
outi={NaN若 otheri=NaN0若 inputi=0.0inputi∗log(otheri)其他情况\text{out}_{i} = \begin{cases}
\text{NaN} & \text{if } \text{other}_{i} = \text{NaN} \\
0 & \text{if } \text{input}_{i} = 0.0 \\
\text{input}{i} * \log{(\text{other}{i})} & \text{otherwise}
\end{cases}
outi=⎩⎨⎧NaN0inputi∗log(otheri)若 otheri=NaN若 inputi=0.0其他情况类似于 SciPy 的 scipy.special.xlogy 函数。
参数
input (Number* 或 Tensor)
– 乘数other (Number* 或 Tensor)
– 参数
注意:input
和 other
中至少有一个必须是张量。
关键字参数
out (Tensor, 可选)
– 输出张量。
示例:
>>> x = torch.zeros(5,)
>>> y = torch.tensor([-1, 0, 1, float('inf'), float('nan')])
>>> torch.special.xlogy(x, y)
tensor([0., 0., 0., 0., nan])
>>> x = torch.tensor([1, 2, 3])
>>> y = torch.tensor([3, 2, 1])
>>> torch.special.xlogy(x, y)
tensor([1.0986, 1.3863, 0.0000])
>>> torch.special.xlogy(x, 4)
tensor([1.3863, 2.7726, 4.1589])
>>> torch.special.xlogy(2, y)
tensor([2.1972, 1.3863, 0.0000])
torch.special.zeta(input, other, *, out=None) → Tensor
逐元素计算 Hurwitz zeta 函数。
$$\zeta(x, q) = \sum_{k=0}^{\infty} \frac{1}{(k + q)^x}
$$
参数
input ( Tensor )
– 对应 x 的输入张量。other ( Tensor )
– 对应 q 的输入张量。
注意:当 q = 1 时即为黎曼 zeta 函数。
关键字参数
out ( Tensor , optional)
– 输出张量。
示例:
>>> x = torch.tensor([2., 4.])
>>> torch.special.zeta(x, 1)
tensor([1.6449, 1.0823])
>>> torch.special.zeta(x, torch.tensor([1., 2.]))
tensor([1.6449, 0.0823])
>>> torch.special.zeta(2, torch.tensor([1., 2.]))
tensor([1.6449, 0.6449])
torch.overrides
该模块提供了多种辅助函数,用于支持__torch_function__
协议。有关__torch_function__
协议的更多详细信息,请参阅扩展torch Python API。
函数
torch.overrides.get_ignored_functions()
返回无法被__torch_function__
覆盖的公共函数。
返回值:一个包含torch API中公开但无法通过__torch_function__
覆盖的函数的元组。这主要是因为这些函数的参数都不是张量或类张量对象。
返回类型:set[Callable]
示例
>>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions()
True
>>> torch.add in torch.overrides.get_ignored_functions()
False
torch.overrides.get_overridable_functions()
可通过 __torch_function__
重写的函数列表
返回值:一个字典,将包含可重写函数的命名空间映射到该命名空间中可被重写的函数。
返回类型:Dict[Any, List[Callable]]
torch.overrides.resolve_name(f)
获取传递给__torch_function__
的函数的人类可读字符串名称
参数
f (Callable)
– 需要解析名称的函数。
返回值:该函数的名称;如果对其进行求值,应能返回输入函数。
返回类型:str
torch.overrides.get_testing_overrides()
返回一个包含所有可覆盖函数的虚拟重载的字典
返回值:一个字典,将 PyTorch API 中的可覆盖函数映射到具有相同签名的 lambda 函数,这些 lambda 函数无条件返回 -1。这些 lambda 函数对于测试定义了 __torch_function__
类型的 API 覆盖率非常有用。
返回类型:Dict[Callable, Callable]
示例
>>> import inspect
>>> my_add = torch.overrides.get_testing_overrides()[torch.add]
>>> inspect.signature(my_add)
<Signature (input, other, out=None)>
torch.overrides.handle_torch_function(public_api, relevant_args, *args, **kwargs)
实现一个检查__torch_function__
重载的函数。
在C++实现中,与此函数等效的是torch::autograd::handle_torch_function。
参数
public_api (function)
- 最初以public_api(args, *kwargs)
形式调用的公开torch API函数,现在正在检查其参数。relevant_args (iterable)
- 需要检查__torch_function__
方法的参数迭代器。args (tuple)
- 最初传入public_api
的任意位置参数。kwargs (tuple)
- 最初传入public_api
的任意关键字参数。
返回
根据情况返回调用implementation
或__torch_function__
方法的结果。
返回类型
object
:raises TypeError: 如果找不到实现。
示例
>>> def func(a):
... if has_torch_function_unary(a):
... return handle_torch_function(func, (a,), a)
... return a + 0
torch.overrides.has_torch_function()
检查可迭代对象中的元素是否实现了__torch_function__
,或者是否启用了__torch_function__
模式。注意:精确的Tensor
和Parameter
被视为不可调度类型。此方法用于保护对handle_torch_function()
的调用,不要用它来检测对象是否类似Tensor——请改用is_tensor_like()
。
:param relevant_args: 需要检查__torch_function__
方法的可迭代对象或参数
:type relevant_args: iterable
返回值
如果relevant_args中任何元素实现了__torch_function__
则返回True,否则返回False。
返回类型 : bool
另请参阅
torch.is_tensor_like
检测对象是否为Tensor-like(包括精确的Tensor
)
torch.overrides.is_tensor_like(inp)
如果传入的输入是类张量(Tensor-like)对象,则返回True
。
当前实现中,只要输入对象的类型具有__torch_function__
属性即视为类张量。
示例:
张量的子类通常属于类张量对象。
>>> class SubTensor(torch.Tensor): ...
>>> is_tensor_like(SubTensor([0]))
True
内置类型或用户自定义类型通常不具备 Tensor 的特性。
>>> is_tensor_like(6)
False
>>> is_tensor_like(None)
False
>>> class NotATensor: ...
>>> is_tensor_like(NotATensor())
False
但是,可以通过实现 __torch_function__
使它们具备类似张量的特性。
>>> class TensorLike:
... @classmethod
... def __torch_function__(cls, func, types, args, kwargs):
... return -1
>>> is_tensor_like(TensorLike())
True
torch.overrides.is_tensor_method_or_property(func)
如果传入的函数是 torch.Tensor
方法或属性的处理程序(如传入 __torch_function__
时),则返回 True。
注意:对于属性,必须传入其 __get__
方法。
这在以下情况下尤其需要:
1、方法/属性有时不包含 module 槽位
2、它们要求第一个传入参数必须是 torch.Tensor
的实例
示例
>>> is_tensor_method_or_property(torch.Tensor.add)
True
>>> is_tensor_method_or_property(torch.add)
False
返回类型:bool
torch.overrides.wrap_torch_function(dispatcher)
用__torch_function__
相关功能包装给定的函数。
参数
dispatcher (Callable)
– 一个可调用对象,返回传入函数中的类Tensor对象的可迭代集合。
注意:此装饰器可能会降低代码性能。通常,将代码表达为一系列自身支持__torch_function__
的函数就足够了。如果您遇到罕见情况(例如在封装底层库时也需要使其支持类Tensor对象),则可以使用此函数。
示例
>>> def dispatcher(a): # Must have the same signature as func
... return (a,)
>>> @torch.overrides.wrap_torch_function(dispatcher)
>>> def func(a): # This will make func dispatchable by __torch_function__
... return a + 0
torch.package
torch.package
提供了创建包含工件和任意 PyTorch 代码的包的支持。这些包可以被保存、共享,用于在之后的时间或不同的机器上加载和执行模型,甚至可以使用 torch::deploy
部署到生产环境。
本文档包含教程、操作指南、说明和 API 参考,将帮助您了解更多关于 torch.package
的信息以及如何使用它。
警告:此模块依赖于不安全的 pickle
模块。仅解包您信任的数据。
恶意构造的 pickle 数据可能会在解包过程中执行任意代码。切勿解包可能来自不受信任来源或可能被篡改的数据。
更多信息,请查阅 pickle
模块的文档。
教程
打包你的第一个模型
我们提供了一个教程,引导你完成打包和解包一个简单模型的流程,该教程可在 Colab 上查看。完成这个练习后,你将熟悉创建和使用 Torch 包的基本 API。
如何实现…
查看包内包含哪些内容?
将包视为ZIP归档文件处理
torch.package
的容器格式采用ZIP标准,因此任何适用于标准ZIP文件的工具都能用于查看其内容。以下是操作ZIP文件的常用方法:
- 执行
unzip my_package.pt
命令可将torch.package
归档解压到磁盘,便于自由检查其内容。
$ unzip my_package.pt && tree my_package
my_package
├── .data
│ ├── 94304870911616.storage
│ ├── 94304900784016.storage
│ ├── extern_modules
│ └── version
├── models
│ └── model_1.pkl
└── torchvision└── models├── resnet.py└── utils.py
~ cd my_package && cat torchvision/models/resnet.py
...
- Python的
zipfile
模块提供了读写ZIP归档文件内容的标准方法。
from zipfile import ZipFile with ZipFile("my_package.pt") as myzip:file_bytes = myzip.read("torchvision/models/resnet.py")# edit file_bytes in some waymyzip.writestr("torchvision/models/resnet.py", new_file_bytes)
- Vim 原生支持读取 ZIP 压缩包。你甚至可以直接编辑文件并通过
:write
命令将修改写回压缩包!
# add this to your .vimrc to treat `*.pt` files as zip files
au BufReadCmd *.pt call zip#Browse(expand("<amatch>"))~ vi my_package.pt
使用 file_structure()
API
PackageImporter
提供了一个 file_structure()
方法,该方法会返回一个可打印且可查询的 Directory
对象。Directory
对象是一个简单的目录结构,可用于查看 torch.package
的当前内容。
Directory
对象本身可以直接打印,并会输出文件树的表示形式。如需过滤返回的内容,可使用 glob 风格的 include
和 exclude
过滤参数。
with PackageExporter('my_package.pt') as pe:pe.save_pickle('models', 'model_1.pkl', mod)importer = PackageImporter('my_package.pt')
# can limit printed items with include/exclude args
print(importer.file_structure(include=["**/utils.py", "**/*.pkl"], exclude="**/*.storage"))
print(importer.file_structure()) # will print out all files
Output:
# filtered with glob pattern:
# include=["**/utils.py", "**/*.pkl"], exclude="**/*.storage"
─── my_package.pt├── models│ └── model_1.pkl└── torchvision└── models└── utils.py# all files
─── my_package.pt├── .data│ ├── 94304870911616.storage│ ├── 94304900784016.storage│ ├── extern_modules│ └── version├── models│ └── model_1.pkl└── torchvision└── models├── resnet.py└── utils.py
你也可以使用 has_file()
方法查询 Directory
对象。
importer_file_structure = importer.file_structure()
found: bool = importer_file_structure.has_file("package_a/subpackage.py")
查看某个模块为何被列为依赖项?
假设有一个模块 foo
,你想知道为什么 PackageExporter
会将其作为依赖项引入。
PackageExporter.get_rdeps()
方法会返回所有直接依赖 foo
的模块。
如果想查看特定模块 src
如何依赖 foo
,可以使用 PackageExporter.all_paths()
方法,该方法会返回一个 DOT 格式的图表,展示 src
和 foo
之间的所有依赖路径。
如果只想查看 PackageExporter
的完整依赖关系图,可以使用 PackageExporter.dependency_graph_string()
方法。
如何在打包时包含任意资源并后续访问?
PackageExporter
提供了三个方法:save_pickle
、save_text
和 save_binary
,允许你将 Python 对象、文本和二进制数据保存到包中。
with torch.PackageExporter("package.pt") as exporter:# Pickles the object and saves to `my_resources/tensor.pkl` in the archive.exporter.save_pickle("my_resources", "tensor.pkl", torch.randn(4))exporter.save_text("config_stuff", "words.txt", "a sample string")exporter.save_binary("raw_data", "binary", my_bytes)
PackageImporter
提供了三个互补方法:load_pickle
、load_text
和 load_binary
,用于从包中加载 Python 对象、文本数据和二进制数据。
importer = torch.PackageImporter("package.pt")
my_tensor = importer.load_pickle("my_resources", "tensor.pkl")
text = importer.load_text("config_stuff", "words.txt")
binary = importer.load_binary("raw_data", "binary")
自定义类的打包方式
torch.package
允许自定义类的打包方式。这一行为通过以下两种方式实现:在类上定义方法 __reduce_package__
,并定义对应的解包函数。这与为 Python 常规的 pickle 过程定义 __reduce__
类似。
操作步骤:
1、在目标类上定义方法 __reduce_package__(self, exporter: PackageExporter)
。该方法负责将类实例保存到包中,并应返回一个元组,包含对应的解包函数及调用该函数所需的参数。当 PackageExporter
遇到目标类的实例时,会调用此方法。
2、为类定义一个解包函数。该解包函数负责重建并返回类的实例。其函数签名的第一个参数应为 PackageImporter
实例,其余参数由用户自定义。
# foo.py [Example of customizing how class Foo is packaged]
from torch.package import PackageExporter, PackageImporter
import timeclass Foo:def __init__(self, my_string: str):super().__init__()self.my_string = my_stringself.time_imported = 0self.time_exported = 0def __reduce_package__(self, exporter: PackageExporter):"""Called by ``torch.package.PackageExporter``'s Pickler's ``persistent_id`` whensaving an instance of this object. This method should do the work to save this object inside of the ``torch.package`` archive.Returns function w/ arguments to load the object from a ``torch.package.PackageImporter``'s Pickler's ``persistent_load`` function."""# use this pattern to ensure no naming conflicts with normal dependencies, # anything saved under this module name shouldn't conflict with other# items in the packagegenerated_module_name = f"foo-generated._{exporter.get_unique_id()}"exporter.save_text(generated_module_name, "foo.txt", self.my_string + ", with exporter modification!", )time_exported = time.clock_gettime(1)# returns de-packaging function w/ arguments to invoke with return (unpackage_foo, (generated_module_name, time_exported,))def unpackage_foo(importer: PackageImporter, generated_module_name: str, time_exported: float
) -Foo:"""Called by ``torch.package.PackageImporter``'s Pickler's ``persistent_load`` functionwhen depickling a Foo object.Performs work of loading and returning a Foo instance from a ``torch.package`` archive."""time_imported = time.clock_gettime(1)foo = Foo(importer.load_text(generated_module_name, "foo.txt"))foo.time_imported = time_importedfoo.time_exported = time_exportedreturn foo
# example of saving instances of class Fooimport torch
from torch.package import PackageImporter, PackageExporter
import foofoo_1 = foo.Foo("foo_1 initial string")
foo_2 = foo.Foo("foo_2 initial string") with PackageExporter('foo_package.pt') as pe:# save as normal, no extra work necessarype.save_pickle('foo_collection', 'foo1.pkl', foo_1)pe.save_pickle('foo_collection', 'foo2.pkl', foo_2)pi = PackageImporter('foo_package.pt')
print(pi.file_structure())
imported_foo = pi.load_pickle('foo_collection', 'foo1.pkl')
print(f"foo_1 string: '{imported_foo.my_string}'")
print(f"foo_1 export time: {imported_foo.time_exported}")
print(f"foo_1 import time: {imported_foo.time_imported}")
# output of running above script
─── foo_package├── foo-generated│ ├── _0│ │ └── foo.txt│ └── _1│ └── foo.txt├── foo_collection│ ├── foo1.pkl│ └── foo2.pkl└── foo.pyfoo_1 string: 'foo_1 initial string, with reduction modification!'
foo_1 export time: 9857706.650140837
foo_1 import time: 9857706.652698385
如何在源码中检测当前是否运行在包环境中?
PackageImporter
会在初始化每个模块时为其添加 __torch_package__
属性。你的代码可以通过检查该属性是否存在,来判断当前是否处于打包后的运行环境中。
# In foo/bar.py:if "__torch_package__" in dir(): # true if the code is being loaded from a packagedef is_in_package():return TrueUserException = Exception
else:def is_in_package():return FalseUserException = UnpackageableException
现在,代码的行为会根据它是通过Python环境正常导入还是从torch.package
导入而有所不同。
from foo.bar import is_in_packageprint(is_in_package()) # Falseloaded_module = PackageImporter(my_package).import_module("foo.bar")
loaded_module.is_in_package() # True
警告:通常情况下,让代码在打包前后表现不一致是一种不良实践。这会导致难以调试的问题,且问题表现会因代码导入方式的不同而敏感变化。如果你的包预计会被频繁使用,建议重构代码,确保无论以何种方式加载,其行为都保持一致。
如何将代码补丁打入包中?
PackageExporter
提供了 save_source_string()
方法,允许你将任意 Python 源代码保存到指定的模块中。
with PackageExporter(f) as exporter:# Save the my_module.foo available in your current Python environment.exporter.save_module("my_module.foo")# This saves the provided string to my_module/foo.py in the package archive.# It will override the my_module.foo that was previously saved.exporter.save_source_string("my_module.foo", textwrap.dedent("""\def my_function():print('hello world')"""))# If you want to treat my_module.bar as a package# (e.g. save to `my_module/bar/__init__.py` instead of `my_module/bar.py)# pass is_package=True, exporter.save_source_string("my_module.bar", "def foo(): print('hello')\n", is_package=True)importer = PackageImporter(f)
importer.import_module("my_module.foo").my_function() # prints 'hello world'
如何从打包代码中访问包内容?
PackageImporter
实现了 importlib.resources API,用于从包内部访问资源。
with PackageExporter(f) as exporter:# saves text to my_resource/a.txt in the archiveexporter.save_text("my_resource", "a.txt", "hello world!")# saves the tensor to my_pickle/obj.pklexporter.save_pickle("my_pickle", "obj.pkl", torch.ones(2, 2))# see below for module contentsexporter.save_module("foo")exporter.save_module("bar")
importlib.resources
API 允许从打包代码中访问资源。
# foo.py:
import importlib.resources
import my_resource# returns "hello world!"
def get_my_resource():return importlib.resources.read_text(my_resource, "a.txt")
推荐使用 importlib.resources
来访问打包代码中的包内容,因为它符合 Python 标准。不过,也可以直接从打包代码中访问父级 PackageImporter
实例本身。
# bar.py:
import torch_package_importer # this is the PackageImporter that imported this module.# Prints "hello world!", equivalent to importlib.resources.read_text
def get_my_resource():return torch_package_importer.load_text("my_resource", "a.txt")# You also do things that the importlib.resources API does not support, like loading
# a pickled object from the package.
def get_my_pickle():return torch_package_importer.load_pickle("my_pickle", "obj.pkl")
区分打包代码与非打包代码
要判断一个对象的代码是否来自 torch.package
,可使用 torch.package.is_from_package()
函数。
注意:若对象来自某个包但其定义源自标记为 extern
的模块或来自 stdlib
,此检查将返回 False
。
importer = PackageImporter(f)
mod = importer.import_module('foo')
obj = importer.load_pickle('model', 'model.pkl')
txt = importer.load_text('text', 'my_test.txt')assert is_from_package(mod)
assert is_from_package(obj)
assert not is_from_package(txt) # str is from stdlib, so this will return False
如何重新导出已导入的对象?
要通过新的 PackageExporter
重新导出之前由 PackageImporter
导入的对象,必须让导出器知晓原始导入器的存在,这样才能正确找到对象依赖项的源代码。
importer = PackageImporter(f)
obj = importer.load_pickle("model", "model.pkl")# re-export obj in a new package with PackageExporter(f2, importer=(importer, sys_importer)) as exporter:exporter.save_pickle("model", "model.pkl", obj)
如何打包 TorchScript 模块?
要打包 TorchScript 模型,可以使用与其他对象相同的 save_pickle
和 load_pickle
API。TorchScript 对象作为属性或子模块时也支持直接保存,无需额外操作。
# save TorchScript just like any other object with PackageExporter(file_name) as e:e.save_pickle("res", "script_model.pkl", scripted_model)e.save_pickle("res", "mixed_model.pkl", python_model_with_scripted_submodule)
# load as normal
importer = PackageImporter(file_name)
loaded_script = importer.load_pickle("res", "script_model.pkl")
loaded_mixed = importer.load_pickle("res", "mixed_model.pkl"
说明
torch.package
格式概述
torch.package
文件是一个 ZIP 归档文件,通常使用 .pt
扩展名。ZIP 归档内包含两类文件:
- 框架文件:存放在
.data/
目录下 - 用户文件:其余所有文件
例如,以下是一个完整打包的 torchvision
ResNet 模型的结构示例:
resnet
├── .data # All framework-specific data is stored here.
│ │ # It's named to avoid conflicts with user-serialized code.
│ ├── 94286146172688.storage # tensor data
│ ├── 94286146172784.storage
│ ├── extern_modules # text file with names of extern modules (e.g. 'torch')
│ ├── version # version metadata
│ ├── ...
├── model # the pickled model
│ └── model.pkl
└── torchvision # all code dependencies are captured as source files└── models├── resnet.py└── utils.py
框架文件
.data/
目录由 torch.package 所有,其内容被视为私有实现细节。torch.package
格式不保证 .data/
目录内容的具体结构,但所有改动都将保持向后兼容性(即新版本的 PyTorch 始终能够加载旧版 torch.package
)。
目前,.data/
目录包含以下内容:
version
:序列化格式的版本号,用于让torch.package
的导入基础设施知道如何加载该包。extern_modules
:被视为extern
的模块列表。extern
模块将使用加载环境的系统导入器进行导入。*.storage
:序列化的张量数据。
.data
├── 94286146172688.storage
├── 94286146172784.storage
├── extern_modules
├── version
├── ...
用户文件
归档中的所有其他文件都是由用户放置的。其目录结构与 Python 的常规包完全一致。若想深入了解 Python 的打包机制,请参阅这篇文章(内容稍有过时,具体实现细节请以 Python 参考文档为准)。
<package root>
├── model # the pickled model
│ └── model.pkl
├── another_package
│ ├── __init__.py
│ ├── foo.txt # a resource file , see importlib.resources
│ └── ...
└── torchvision└── models├── resnet.py # torchvision.models.resnet└── utils.py # torchvision.models.utils
torch.package
如何查找代码依赖项
分析对象的依赖关系
当你调用 save_pickle(obj, ...)
时,PackageExporter
会正常地对对象进行 pickle 序列化。随后,它会使用标准库模块 pickletools
来解析 pickle 字节码。
在 pickle 序列化过程中,对象会与一个 GLOBAL
操作码一起保存,该操作码描述了如何找到对象类型的实现位置,例如:
GLOBAL 'torchvision.models.resnet Resnet`
依赖解析器会收集所有GLOBAL
操作,并将它们标记为待序列化对象的依赖项。
有关序列化及pickle格式的更多信息,请参阅Python官方文档。
分析模块依赖关系
当识别出某个Python模块作为依赖项时,torch.package
会遍历该模块的Python抽象语法树(AST)表示,并查找其中的导入语句。它完全支持标准导入形式:from x import y
、import z
、from w import v as u
等。当遇到这些导入语句时,torch.package
会将导入的模块注册为依赖项,随后以同样的AST遍历方式解析这些依赖模块。
注意:AST解析对__import__(...)
语法的支持有限,且不支持importlib.import_module
调用。通常不应期望torch.package
能检测到动态导入。
依赖管理
torch.package
会自动发现你的代码和对象所依赖的 Python 模块。这一过程称为依赖解析。
对于依赖解析器找到的每个模块,你必须指定一个操作来处理它。
允许的操作包括:
intern
:将该模块打包到包中。extern
:声明该模块为包的外部依赖项。mock
:将该模块替换为存根。deny
:依赖此模块会在包导出时引发错误。
此外,还有一个重要的操作虽然技术上不属于 torch.package
的一部分:
- 重构:移除或更改代码中的依赖项。
注意,操作仅针对整个 Python 模块定义,无法仅打包模块中的某个函数或类而忽略其余部分。
这是有意为之的设计。Python 并未提供模块内对象之间的清晰边界,模块是依赖组织的唯一明确定义单元,因此 torch.package
也采用这一标准。
操作通过模式应用于模块。模式可以是模块名称(如 "foo.bar"
)或通配符(如 "foo.**"
)。你可以使用 PackageExporter
上的方法将模式与操作关联,例如:
my_exporter.intern("torchvision.**")
my_exporter.extern("numpy")
如果模块匹配某个模式,就会对其应用相应的操作。对于给定的模块,系统会按照模式定义的顺序依次检查,并执行第一个匹配的操作。
intern
如果一个模块被标记为intern
,它将被放入包中。
此操作适用于你的模型代码,或任何你想打包的相关代码。例如,如果你要打包torchvision
中的ResNet模型,就需要将模块torchvision.models.resnet
标记为intern
。
当导入包时,如果打包的代码尝试导入一个被标记为intern
的模块,PackageImporter
会在你的包内查找该模块。如果找不到该模块,则会抛出错误。这确保了每个PackageImporter
与加载环境隔离——即使my_interned_module
同时在你的包和加载环境中可用,PackageImporter
也只会使用包内的版本。
注意:只有Python源码模块可以被标记为intern
。其他类型的模块(如C扩展模块和字节码模块)如果尝试标记为intern
会抛出错误。这类模块需要被标记为mock
或extern
。
extern
如果一个模块被声明为extern
,它将不会被包含在包中。相反,该模块会被添加到当前包的外部依赖列表中。你可以在package_exporter.extern_modules
中找到这个列表。
当导入包时,如果打包后的代码尝试导入一个被声明为extern
的模块,PackageImporter
会使用默认的Python导入器来查找该模块,就像执行了importlib.import_module("my_externed_module")
一样。如果找不到该模块,则会抛出错误。
通过这种方式,你可以在包中依赖第三方库(如numpy
和scipy
),而无需将它们一并打包。
警告:如果任何外部库发生了不向后兼容的更改,你的包可能无法加载。如果需要长期保证包的复现性,请尽量减少使用extern
。
mock
当一个模块被mock
时,它不会被打包。取而代之的是一个存根模块会被打包到该位置。这个存根模块允许你从中获取对象(因此from my_mocked_module import foo
不会报错),但任何使用该对象的操作都会引发NotImplementedError
。
mock
应该用于那些你"确定"在加载的包中不需要,但仍希望在非打包内容中可用的代码。例如初始化/配置代码,或仅用于调试/训练的代码。
警告:一般来说,mock
应该作为最后手段使用。它会导致打包代码和非打包代码之间的行为差异,可能引发后续混淆。建议优先通过重构代码来移除不需要的依赖项。
代码重构
管理依赖关系的最佳方式就是彻底消除依赖!通过重构代码,我们往往可以移除不必要的依赖。以下是编写低依赖代码的指导原则(这些原则本身也是优秀的编码实践):
只引入真正用到的内容。不要在代码中保留未使用的导入项。依赖解析器无法智能识别这些未使用的导入,仍会尝试处理它们。
精确限定导入范围。例如,与其写import foo
然后在代码中使用foo.bar.baz
,不如直接写from foo.bar import baz
。这种方式能更精准地声明实际依赖(foo.bar
),让依赖解析器明确你不需要引入整个foo
包。
将包含无关功能的大文件拆分为小模块。如果你的utils
模块混杂了大量无关功能,任何依赖该模块的代码都不得不引入许多无关依赖——即使你只需要其中一小部分功能。更好的做法是创建功能单一的小模块,这些模块可以彼此独立地打包。
模式
模式允许您通过简洁的语法来指定模块组。其语法和行为遵循 Bazel/Buck 的 glob() 规范。
我们将尝试与模式匹配的模块称为候选对象。候选对象由通过分隔符字符串分隔的多个段组成,例如 foo.bar.baz
。
一个模式包含一个或多个段。段可以是以下类型:
- 字面字符串(如
foo
),表示精确匹配 - 包含通配符的字符串(如
torch
或foo*baz*
)。通配符可以匹配任意字符串,包括空字符串 - 双通配符(
**
)。这将匹配零个或多个完整段
示例:
torch.**
:匹配torch
及其所有子模块,例如torch.nn
和torch.nn.functional
torch.*
:匹配torch.nn
或torch.functional
,但不匹配torch.nn.functional
或torch
torch*.**
:匹配torch
、torchvision
及其所有子模块
在指定操作时,您可以传入多个模式,例如
exporter.intern(["torchvision.models.**", "torchvision.utils.**"])
模块将匹配此操作,只要符合其中任一模式。
您还可以指定要排除的模式,例如
exporter.mock("**", exclude=["torchvision.**"])
如果模块匹配任何排除模式,则该模块不会与此操作匹配。在本示例中,我们模拟了除 torchvision
及其子模块之外的所有模块。
当一个模块可能匹配多个操作时,将优先采用第一个定义的操作。
torch.package
的注意事项
避免在模块中使用全局状态
Python 可以非常方便地在模块作用域内绑定对象和运行代码。这通常没有问题——毕竟函数和类也是通过这种方式绑定到名称的。但当你定义一个模块作用域内的可变对象时,就会引入可变的全局状态,情况就变得复杂了。
可变全局状态非常有用——它可以减少样板代码、允许开放注册到表中等等。但除非非常谨慎地使用,否则在与 torch.package
一起使用时可能会导致问题。
每个 PackageImporter
都会为其内容创建一个独立的环境。这很好,因为它意味着我们可以加载多个包并确保它们彼此隔离。但当模块的编写方式假设存在共享的可变全局状态时,这种行为可能会导致难以调试的错误。
类型在包与加载环境之间不共享
通过 PackageImporter
导入的任何类,都将是该导入器特有的类版本。例如:
from foo import MyClassmy_class_instance = MyClass()with PackageExporter(f) as exporter:exporter.save_module("foo")importer = PackageImporter(f)
imported_MyClass = importer.import_module("foo").MyClassassert isinstance(my_class_instance, MyClass) # works
assert isinstance(my_class_instance, imported_MyClass) # ERROR!
在这个示例中,MyClass
和 imported_MyClass
不是同一类型。虽然在这个特定例子中,MyClass
和 imported_MyClass
的实现完全相同,你可能会认为它们可以视为同一个类。但设想一下,如果 imported_MyClass
来自一个旧版本的包,其中 MyClass
的实现完全不同——这种情况下,将它们视为同一个类是不安全的。
实际上,每个导入器都有一个唯一标识类的前缀:
print(MyClass.__name__) # prints "foo.MyClass"
print(imported_MyClass.__name__) # prints <torch_package_0>.foo.MyClass
这意味着当其中一个参数来自某个包而另一个不是时,你不应期望isinstance
检查能正常工作。如果需要此功能,请考虑以下选项:
- 采用鸭子类型(直接使用类而非显式检查其是否属于给定类型)。
- 将类型关系明确作为类契约的一部分。例如,可以添加属性标签
self.handler = "handle_me_this_way"
,并让客户端代码检查handler
的值而非直接检查类型。
torch.package
如何实现包之间的隔离
每个 PackageImporter
实例都会为其模块和对象创建一个独立的隔离环境。包中的模块只能导入其他已打包的模块或被标记为 extern
的模块。如果使用多个 PackageImporter
实例加载同一个包,将会得到多个互不干扰的独立环境。
这一机制是通过扩展 Python 的导入系统实现的,PackageImporter
是一个自定义导入器。它提供了与 importlib
导入器相同的核心 API,即实现了 import_module
和 __import__
方法。
当调用 PackageImporter.import_module()
时,PackageImporter
会像系统导入器一样构造并返回一个新模块。不同之处在于,它会修改返回的模块,使其在后续导入请求时使用当前 PackageImporter
实例(即从包内查找资源),而不是从用户的 Python 环境中搜索。
名称修饰(Mangling)
为了避免混淆(“这个 foo.bar
对象是来自我的包,还是来自 Python 环境?”),PackageImporter
会通过添加修饰前缀来修改所有导入模块的 __name__
和 __file__
属性。
对于 __name__
,像 torchvision.models.resnet18
这样的名称会被修饰为 <torch_package_0>.torchvision.models.resnet18
。
对于 __file__
,像 torchvision/models/resnet18.py
这样的路径会被修饰为 <torch_package_0>.torchvision/modules/resnet18.py
。
名称修饰有助于避免不同包之间模块名的意外冲突,并通过使堆栈跟踪和打印语句更清晰地显示它们是否引用打包代码来辅助调试。有关名称修饰的开发者详细信息,请参阅 torch/package/
目录下的 mangling.md
文件。
API 参考
class torch.package.PackagingError(dependency_graph, debug=False)
当导出包出现问题时,会引发此异常。
PackageExporter
会尝试收集所有错误并一次性展示给您。
class torch.package.EmptyMatchError
当模拟对象(mock)或外部依赖(extern)被标记为allow_empty=False
,但在打包过程中未匹配到任何模块时,会抛出此异常。
class torch.package.PackageExporter(f, importer=<torch.package.importer._SysImporter object>, debug=False)
导出器(Exporters)允许你将代码包、序列化的Python数据以及任意二进制和文本资源打包成自包含的独立包。
导入器(Imports)能够以封闭方式加载这些代码,确保代码从包内加载而非通过常规Python导入系统。这种机制使得PyTorch模型代码和数据可以被打包,以便在服务器上运行或用于未来的迁移学习。
包中的代码在创建时会从原始源文件逐份复制,其文件格式采用特殊组织的zip压缩包。包的使用者可以解压该包并编辑代码,以实现自定义修改。
包的导入器会确保模块代码只能从包内加载,除非模块通过extern()
明确声明为外部依赖。zip压缩包中的extern_modules
文件列出了该包所有外部依赖的模块。
这种机制避免了"隐式"依赖问题——即包在本地运行时能正常导入本地安装的依赖,但当包被复制到其他机器时就会运行失败。
当源代码被添加到包中时,导出器可选择性地扫描代码以发现更多依赖关系(通过设置dependencies=True
)。它会查找import语句,将相对引用解析为完整模块名,并执行用户指定的操作(参见:extern()
、mock()
和intern()
)。
__init__(f, importer=<torch.package.importer._SysImporter object>, debug=False)
创建一个导出器。
参数
f ( Union [str,* PathLike[str],* [IO](https://docs.python.org/3/library/typing.html#typing.IO "(in Python v3.13)")[bytes ]])
- 导出目标位置。可以是包含文件名的string
/Path
对象,也可以是二进制I/O对象。importer ( Union [Importer*,* Sequence [Importer]])
- 如果传入单个Importer,则使用它来搜索模块。如果传入一个importer序列,则会基于它们构建一个OrderedImporter
。debug ([bool])
- 如果设为True,会将损坏模块的路径添加到PackagingErrors中。
add_dependency(module_name, dependencies=True)
根据用户指定的模式,将给定模块添加到依赖关系图中。
all_paths(src, dst)
返回从源节点到目标节点所有路径的子图点表示形式。
返回值:包含从源节点到目标节点所有路径的点表示字符串。
参考文档:Graphviz语言规范
返回类型:str
close()
将包写入文件系统。调用close()
方法后,所有后续操作都将无效。
建议改用资源守卫语法:
with PackageExporter("file.zip") as e:...
denied_modules()
返回当前被拒绝的所有模块。
返回值:一个包含该包中被拒绝模块名称的列表。
返回类型:list [str]
deny(include, *, exclude=())
根据给定的glob模式,从包可导入的模块列表中屏蔽匹配名称的模块。
如果发现任何匹配包的依赖项,将抛出 PackagingError
错误。
参数
include (Union[List[str],* str])
- 可以是字符串(例如"my_package.my_subpackage"
)或模块名称列表,用于指定需要外部化的模块。也支持glob风格的模式,详见mock()
。exclude (Union[List[str],* str])
- 可选参数,用于排除与include字符串匹配的部分模式。
dependency_graph_string()
返回包中依赖关系的有向图字符串表示形式。
返回值:包中依赖关系的字符串表示形式。
返回类型:str
extern(include, *, exclude=(), allow_empty=True)
将 module
包含在包可导入的外部模块列表中。
这将阻止依赖项发现机制将其保存到包中。导入器会直接从标准导入系统加载外部模块。
外部模块的代码也必须存在于加载包的进程中。
参数
include (Union[List[str],* str])
– 字符串(例如"my_package.my_subpackage"
)或待外部化的模块名称字符串列表。也可使用通配符模式,如mock()
所述。exclude (Union[List[str],* str])
– 可选参数,用于排除与 include 字符串匹配的部分模式。allow_empty ([bool])
– 可选标志,指定本次调用extern
方法所定义的外部模块是否必须在打包时匹配到某些模块。若以allow_empty=False
添加外部模块通配符模式,且在匹配到任何模块前调用close()
(显式调用或通过__exit__
调用),则会抛出异常。若allow_empty=True
则不会抛出此类异常。
(注:严格保留所有代码块、术语标记及链接格式,技术术语如extern
、mock()
等未作翻译,被动语态已转换为主动表达,长句进行了合理拆分)
externed_modules()
返回当前被外部化的所有模块。
返回值:一个包含该包中将被外部化的模块名称的列表。
返回类型:list [str]
get_rdeps(module_name)
返回所有依赖于模块 module_name
的模块列表。
返回值:一个包含依赖 module_name
的模块名称列表。
返回类型:list [str]
get_unique_id()
获取一个ID。该ID保证在此包中只会被分配一次。
返回类型:str
intern(include, *, exclude=(), allow_empty=True)
指定需要打包的模块。模块必须匹配某些intern
模式才能被包含在包中,并递归处理其依赖项。
参数
include (Union[List[str],* str])
– 可以是字符串(例如"my_package.my_subpackage")或模块名称列表,用于指定需要外部化的模块。该参数也支持glob风格的模式匹配,具体说明见mock()
文档。exclude (Union[List[str],* str])
– 可选参数,用于排除与include模式匹配的特定模式。allow_empty ([bool])
– 可选标志,指定通过intern
方法设置的内部模块在打包时是否必须匹配到某些模块。如果添加intern
模块glob模式时设置allow_empty=False
,且在调用close()
(显式调用或通过__exit__
)时没有任何模块匹配该模式,则会抛出异常。若设置allow_empty=True
,则不会抛出此类异常。
interned_modules()
返回当前被内部化的所有模块。
返回值:一个包含将被此包内部化的模块名称的列表。
返回类型:list [str]
mock(include, *, exclude=(), allow_empty=True)
用模拟实现替换某些必需的模块。被模拟的模块将返回一个虚假对象,用于处理任何从其访问的属性。由于我们采用逐文件复制的方式,依赖解析有时会找到被模型文件导入但实际功能从未使用的文件(例如自定义序列化代码或训练辅助工具)。
使用此函数可以模拟这些功能,而无需修改原始代码。
参数
include (Union[List[str],* str])
–
一个字符串(例如 "my_package.my_subpackage"
)或字符串列表,表示需要被模拟的模块名称。字符串也可以使用通配符模式,可能匹配多个模块。任何符合此模式字符串的必需依赖项都将被自动模拟。
示例:
'torch.**'
– 匹配 torch
及其所有子模块,例如 'torch.nn'
和 'torch.nn.functional'
'torch.*'
– 匹配 'torch.nn'
或 'torch.functional'
,但不匹配 'torch.nn.functional'
exclude (Union[List[str],* str])
– 一个可选模式,用于排除某些匹配include
字符串的模式。
例如,include='torch.**', exclude='torch.foo'
将模拟除'torch.foo'
之外的所有 torch 包。默认值为[]
。allow_empty ([bool])
– 一个可选标志,指定通过调用mock()
方法指定的模拟实现是否必须在打包过程中匹配到某些模块。如果以allow_empty=False
添加模拟,并且在调用close()
(显式调用或通过__exit__
)时该模拟未匹配到导出包所使用的模块,则会抛出异常。
如果 allow_empty=True
,则不会抛出此类异常。
mocked_modules()
返回当前被模拟的所有模块。
返回值:包含本包中将被模拟的模块名称列表。
返回类型:list [str]
register_extern_hook(hook)
在导出器上注册一个外部钩子。
每当有模块匹配 extern()
模式时,就会调用该钩子。
钩子函数需要遵循以下签名:
hook(exporter: PackageExporter, module_name: str) -None
钩子将按照注册顺序被调用。
返回值:一个句柄,可用于通过调用 handle.remove()
来移除已添加的钩子。
返回类型:torch.utils.hooks.RemovableHandle
register_intern_hook(hook)
在导出器上注册一个内部钩子。
每当模块匹配 intern()
模式时,都会调用该钩子。
它应具有以下签名:
hook(exporter: PackageExporter, module_name: str) -None
钩子将按照注册顺序被调用。
返回值:一个句柄,可用于通过调用 handle.remove()
来移除已添加的钩子。
返回类型:torch.utils.hooks.RemovableHandle
register_mock_hook(hook)
在导出器上注册一个模拟钩子。
每当模块匹配 mock()
模式时,该钩子就会被调用。
它应具有以下签名:
hook(exporter: PackageExporter, module_name: str) -None
钩子函数将按照注册顺序依次调用。
返回值:返回一个句柄,可通过调用 handle.remove()
来移除已添加的钩子。
返回类型:torch.utils.hooks.RemovableHandle
save_binary(package, resource, binary)
将原始字节保存到包中。
参数
package (str)
– 该资源所属的模块包名称(例如"my_package.my_subpackage"
)。resource (str)
– 资源的唯一名称,用于加载时识别。binary (str)
– 要保存的数据。
save_module(module_name, dependencies=True)
将 module
的代码保存到包中。模块代码的解析过程是:先通过 importers
路径查找模块对象,然后利用其 __file__
属性定位源代码。
参数
module_name (str)
– 例如my_package.my_subpackage
,代码将被保存以提供该包的实现代码。dependencies ([bool], 可选)
– 若设为True
,则会扫描源代码中的依赖项。
save_pickle(package, resource, obj, dependencies=True, pickle_protocol=3)
使用pickle将Python对象保存到归档文件中。功能等同于torch.save()
,但会保存到归档而非独立文件。标准pickle不会保存代码,仅保存对象。
如果dependencies
参数为True,此方法还会扫描被pickle的对象,识别重建它们所需的模块,并保存相关代码。
要保存type(obj).__name__
为my_module.MyObject
的对象时,my_module.MyObject
必须能根据importer
顺序解析为对象的类。当保存先前已打包的对象时,importer
列表中必须包含import_module
方法才能正常工作。
参数
package (str)
- 该资源所属的模块包名称(例如"my_package.my_subpackage"
)。resource (str)
- 资源的唯一名称,用于加载时识别。obj (Any)
- 要保存的对象,必须可被pickle序列化。dependencies ([bool], 可选)
- 若为True
,则会扫描源代码中的依赖项。
save_source_file(module_name, file_or_directory, dependencies=True)
将本地文件系统中的 file_or_directory
添加到源码包中,为 module_name
提供代码。
参数
module_name (str)
– 例如"my_package.my_subpackage"
,代码将被保存以提供该包的代码。file_or_directory (str)
– 代码文件或目录的路径。如果是目录,将递归复制该目录中的所有 Python 文件,使用save_source_file()
。如果文件名为"/__init__.py"
,则该代码被视为一个包。dependencies ([bool], 可选)
– 如果为True
,则会扫描源码中的依赖项。
save_source_string(module_name, src, is_package=False, dependencies=True)
将src
作为导出包中module_name
的源代码添加。
参数
module_name (str)
– 例如my_package.my_subpackage
,代码将被保存以提供该包的源代码。src (str)
– 要为此包保存的Python源代码。is_package ([bool], 可选)
– 如果为True
,则将此模块视为包。包允许包含子模块(例如my_package.my_subpackage.my_subsubpackage
),并且可以在其中保存资源。默认为False
。dependencies ([bool], 可选)
– 如果为True
,则会扫描源代码中的依赖项。
save_text(package, resource, text)
将文本数据保存到包中。
参数
package (str)
– 该资源所属模块包的名称(例如"my_package.my_subpackage"
)。resource (str)
– 资源的唯一名称,用于加载时识别。text (str)
– 要保存的内容。
class torch.package.PackageImporter(file_or_buffer, module_allowed=<function PackageImporter.<lambda>>)
导入器(Importers)允许您加载由PackageExporter
写入包的代码。
代码以密封方式加载,使用包中的文件而非常规Python导入系统。这使得PyTorch模型代码和数据可以被打包,从而能在服务器上运行或用于未来的迁移学习。
包导入器确保模块中的代码只能从包内部加载,除非是导出时明确列为外部的模块。
zip存档中的extern_modules
文件列出了包外部依赖的所有模块。
这避免了"隐式"依赖问题——即包在本地运行时能正常工作(因为导入了本地安装的包),但当包被复制到其他机器时就会失败。
__init__(file_or_buffer, module_allowed=<function PackageImporter.<lambda>>)
打开 file_or_buffer
以进行导入操作。此操作会检查导入的包是否仅依赖 module_allowed
所允许的模块。
参数
file_or_buffer ( Union [str,* PathLike[str],* [IO](https://docs.python.org/3/library/typing.html#typing.IO "(in Python v3.13)")[bytes ], PyTorchFileReader])
– 类文件对象(需实现read()
、readline()
、tell()
和seek()
方法)、字符串或包含文件名的os.PathLike
对象。module_allowed (Callable[[str],* [bool]], optional)
– 用于判断是否应允许外部提供模块的方法。可用于确保加载的包不依赖服务器不支持的模块。默认允许所有模块。
抛出异常
ImportError` – 如果包尝试使用被禁止的模块。
file_structure(*, include='**', exclude=())
返回包 zipfile 的文件结构表示。
参数
include (Union[List[str],* str])
– 可选字符串(如"my_package.my_subpackage"
)或可选字符串列表,用于指定要包含在 zipfile 表示中的文件名。也可以是 glob 风格的模式,如PackageExporter.mock()
中所述。exclude (Union[List[str],* str])
– 可选模式,用于排除名称匹配该模式的文件。
返回
Directory
返回类型:Directory
id()
返回 torch.package 用于区分 PackageImporter
实例的内部标识符。
格式类似:
<torch_package_0>
import_module(name, package=None)
如果模块尚未加载,则从包中加载该模块并返回。模块会被加载到导入者的本地命名空间,并出现在 self.modules
中而非 sys.modules
里。
参数
name (str)
– 要加载模块的完整限定名。package ([type], 可选)
– 未使用,但为了与importlib.import_module
的函数签名保持一致而保留。默认为None
。
返回值
(可能已加载的)模块对象。
返回类型
types.ModuleType
load_binary(package, resource)
加载原始字节数据。
参数
package (str)
– 模块包的名称(例如"my_package.my_subpackage"
)。resource (str)
– 资源的唯一名称。
返回值:已加载的数据。
返回类型:bytes
load_pickle(package, resource, map_location=None)
从包中反序列化资源,加载构造对象所需的所有模块
使用 import_module()
。
参数
package (str)
– 模块包的名称(例如"my_package.my_subpackage"
)。resource (str)
– 资源的唯一名称。map_location
– 传递给 torch.load 以确定张量如何映射到设备。默认为None
。
返回
反序列化后的对象。
返回类型:任意
load_text(package, resource, encoding='utf-8', errors='strict')
加载字符串。
参数
package (str)
– 模块包的名称(例如"my_package.my_subpackage"
)。resource (str)
– 资源的唯一名称。encoding (str, 可选)
– 传递给decode
。默认为'utf-8'
。errors (str, 可选)
– 传递给decode
。默认为'strict'
。
返回值:加载的文本。
返回类型:str
python_version()
返回用于创建此软件包的 Python 版本。
注意:此功能为实验性质,不具备向前兼容性。计划后续将其迁移至锁文件中。
返回值:Optional[str]
返回 Python 版本号(例如 3.8.9),如果该软件包未存储版本信息则返回 None
class torch.package.Directory(name, is_dir)
一种文件结构表示形式。它以目录节点形式组织,每个节点包含其子目录列表。通过调用 PackageImporter.file_structure()
可为包创建目录结构。
has_file(filename)
检查文件是否存在于 Directory
中。
参数
filename (str)
- 要搜索的文件路径。
返回值
如果 Directory
包含指定文件则返回 True。
返回类型:bool
torch.profiler
概述
PyTorch Profiler 是一款用于在训练和推理过程中收集性能指标的工具。通过其上下文管理器 API,开发者可以深入分析模型中最耗时的算子、检查输入张量形状和调用堆栈、研究设备内核活动,并可视化执行轨迹。
注意: torch.autograd
模块中的旧版 API 已被视为遗留接口,未来将被弃用。
API 参考
class torch.profiler._KinetoProfile(*, activities=None, record_shapes=False, profile_memory=False, with_stack=False, with_flops=False, with_modules=False, experimental_config=None, execution_trace_observer=None, acc_events=False, custom_trace_id_callback=None)
底层分析器封装了自动梯度分析功能
参数
activities (iterable)
– 用于分析的活跃组列表(CPU、CUDA),支持以下值:
torch.profiler.ProfilerActivity.CPU
, torch.profiler.ProfilerActivity.CUDA
, torch.profiler.ProfilerActivity.XPU
.
默认值:ProfilerActivity.CPU 和(可用时)ProfilerActivity.CUDA 或(可用时)ProfilerActivity.XPU。
record_shapes ([bool])
– 保存算子输入形状信息。profile_memory ([bool])
– 跟踪张量内存分配/释放(详见export_memory_timeline
)。with_stack ([bool])
– 记录算子的源码信息(文件和行号)。with_flops ([bool])
– 使用公式估算特定算子(矩阵乘法和2D卷积)的浮点运算次数。with_modules ([bool])
– 记录与算子调用栈对应的模块层次结构(包括函数名)。
例如:如果模块A的forward调用模块B的forward(其中包含aten::add算子),那么aten::add的模块层次结构就是A.B
注意:目前该功能仅支持TorchScript模型,不支持eager模式模型。
experimental_config (_ExperimentalConfig)
– 一组实验性选项,供Kineto等分析器库使用。注意不保证向后兼容性。execution_trace_observer (ExecutionTraceObserver)
– PyTorch执行轨迹观察器对象。
PyTorch执行轨迹提供了基于图的AI/ML工作负载表示,支持回放基准测试、模拟器和仿真器。
当包含此参数时,观察器的start()和stop()方法将与PyTorch分析器在同一时间窗口被调用。
acc_events ([bool])
– 启用跨多个分析周期的FunctionEvents累积功能
注意:此API为实验性质,未来可能变更。
启用形状和调用栈跟踪会导致额外开销。
当指定record_shapes=True时,分析器会临时持有张量的引用,这可能会阻止某些依赖引用计数的优化,并引入额外的张量拷贝。
add_metadata(key, value)
向跟踪文件中添加用户定义的元数据,包含字符串键和字符串值
add_metadata_json(key, value)
向跟踪文件中添加用户自定义的元数据,包含字符串键和有效的JSON值
events()
返回未聚合的性能分析事件列表,可用于跟踪回调或在性能分析结束后使用
export_chrome_trace(path)
以 Chrome JSON 格式导出收集的跟踪数据。如果启用了 kineto,则仅导出调度中的最后一个周期。
export_memory_timeline(path, device=None)
从分析器收集的内存事件信息树中导出指定设备的数据,并生成时间线图表。使用export_memory_timeline
可导出三种文件格式,通过path
参数的后缀控制:
- 如需生成HTML兼容的图表,使用
.html
后缀。内存时间线图表将以PNG格式嵌入HTML文件中。 - 如需导出由
[时间戳, [按类别划分的内存大小]]
组成的数据点(其中times
是时间戳,sizes
是每个类别的内存使用量),根据后缀选择保存为JSON(.json
)或gzip压缩的JSON(.json.gz
)。 - 如需原始内存事件数据,使用
.raw.json.gz
后缀。每个原始内存事件将包含(时间戳, 操作类型, 字节数, 类别)
,其中:action
是[PREEXISTING, CREATE, INCREMENT_VERSION, DESTROY]
之一category
来自torch.profiler._memory_profiler.Category
枚举
输出:内存时间线数据将以gzip压缩JSON、JSON或HTML格式写入。
export_stacks(path, metric='self_cpu_time_total')
将堆栈跟踪保存到文件
参数
path (str)
- 将堆栈文件保存到此路径;metric (str)
- 使用的指标:“self_cpu_time_total” 或 “self_cuda_time_total”
key_averages(group_by_input_shape=False, group_by_stack_n=0, group_by_overload_name=False)
Averages events, grouping them by operator name and (optionally) input shapes, stack and overload name.
Note: To use shape/stack functionality make sure to set record_shapes/with_stack
when creating profiler context manager.
preset_metadata_json(key, value)
在性能分析器未启动时预设用户自定义元数据,该元数据后续会被添加到跟踪文件中。
元数据格式为字符串键与有效JSON值的组合
toggle_collection_dynamic(enable, activities)
功能说明
可在收集过程中的任意时间点开启/关闭活动收集功能。当前支持切换 Torch 算子(CPU)以及 Kineto 中支持的 CUDA 活动。
参数说明
activities (iterable)
– 用于性能分析的活动组列表,支持以下取值:torch.profiler.ProfilerActivity.CPU
torch.profiler.ProfilerActivity.CUDA
使用示例
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU,torch.profiler.ProfilerActivity.CUDA,]
) as p:code_to_profile_0()// turn off collection of all CUDA activityp.toggle_collection_dynamic(False, [torch.profiler.ProfilerActivity.CUDA])code_to_profile_1()// turn on collection of all CUDA activityp.toggle_collection_dynamic(True, [torch.profiler.ProfilerActivity.CUDA])code_to_profile_2()
print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
class torch.profiler.profile(*, activities=None, schedule=None, on_trace_ready=None, record_shapes=False, profile_memory=False, with_stack=False, with_flops=False, with_modules=False, experimental_config=None, execution_trace_observer=None, acc_events=False, use_cuda=None, custom_trace_id_callback=None)
分析器上下文管理器。
参数
activities (iterable)
– 用于分析的活跃组列表(CPU、CUDA),支持以下值:
torch.profiler.ProfilerActivity.CPU
、torch.profiler.ProfilerActivity.CUDA
、torch.profiler.ProfilerActivity.XPU
。
默认值:ProfilerActivity.CPU 和(如果可用)ProfilerActivity.CUDA 或(如果可用)ProfilerActivity.XPU。
schedule (Callable)
– 可调用对象,接收步骤(int)作为单一参数,并返回
ProfilerAction
值,指定在每一步执行的分析器操作。
on_trace_ready (Callable)
– 当schedule
在分析过程中返回ProfilerAction.RECORD_AND_SAVE
时,每一步调用的可调用对象。record_shapes ([bool])
– 保存操作符输入形状的信息。profile_memory ([bool])
– 跟踪张量内存分配/释放。with_stack ([bool])
– 记录操作符的源信息(文件和行号)。with_flops ([bool])
– 使用公式估算特定操作符(矩阵乘法和2D卷积)的浮点运算次数(FLOPs)。with_modules ([bool])
– 记录与操作符调用堆栈对应的模块层次结构(包括函数名称)。例如,如果模块A的前向调用模块B的前向,其中包含一个aten::add操作符,那么aten::add的模块层次结构是A.B。
注意,目前此功能仅支持TorchScript模型,不支持eager模式模型。
experimental_config (_ExperimentalConfig)
– 一组用于Kineto库功能的实验性选项。注意,不保证向后兼容性。execution_trace_observer (ExecutionTraceObserver)
– 一个PyTorch执行跟踪观察器对象。
PyTorch执行跟踪 提供基于图的AI/ML工作负载表示,并支持重放基准测试、模拟器和仿真器。
当包含此参数时,观察器的start()和stop()将在与PyTorch分析器相同的时间窗口内调用。请参阅下面的示例部分获取代码示例。
acc_events ([bool])
– 启用跨多个分析周期的FunctionEvents累积。use_cuda ([bool])
–
自版本1.8.1起弃用:改用 activities
。
注意:使用 schedule()
生成可调用的计划。
非默认计划在分析长时间训练作业时非常有用,允许用户在训练过程的不同迭代中获取多个跟踪。
默认计划只是在上下文管理器持续时间内连续记录所有事件。
注意:使用 tensorboard_trace_handler()
生成TensorBoard的结果文件:
on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name)
分析完成后,结果文件可以在指定目录中找到。使用命令:
tensorboard --logdir dir_name
在TensorBoard中查看结果。
更多信息,请参阅 PyTorch Profiler TensorBoard Plugin
注意:启用形状和堆栈跟踪会导致额外的开销。
当指定record_shapes=True时,分析器将临时持有对张量的引用;这可能会进一步阻止某些依赖于引用计数的优化,并引入额外的张量副本。
示例:
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, ]
) as p:code_to_profile()
print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
使用性能分析器的 schedule
、on_trace_ready
和 step
函数:
# Non-default profiler schedule allows user to turn profiler on and off
# on different iterations of the training loop;
# trace_handler is called every time a new trace becomes available
def trace_handler(prof):print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))# prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json")with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, ], # In this example with wait=1, warmup=1, active=2, repeat=1, # profiler will skip the first step/iteration, # start warming up on the second, record# the third and the forth iterations, # after which the trace will become available# and on_trace_ready (when set) is called;# the cycle repeats starting with the next stepschedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=1), on_trace_ready=trace_handler# on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')# used when outputting for tensorboard) as p:for iter in range(N):code_iteration_to_profile(iter)# send a signal to the profiler that the next iteration has startedp.step()
以下示例展示了如何设置执行跟踪观察器(execution_trace_observer)
with torch.profiler.profile(...execution_trace_observer=(ExecutionTraceObserver().register_callback("./execution_trace.json")), ) as p:for iter in range(N):code_iteration_to_profile(iter)p.step()
你也可以参考 tests/profiler/test_profiler.py
中的 test_execution_trace_with_kineto()
方法。
注意:任何实现了 _ITraceObserver
接口的对象都可以传入使用。
get_trace_id()
返回当前跟踪ID。
set_custom_trace_id_callback(callback)
设置当生成新跟踪ID时要调用的回调函数。
step()
通知性能分析器下一个分析步骤已开始。
class torch.profiler.ProfilerAction(value)
可在指定时间间隔执行的性能分析器操作
class torch.profiler.ProfilerActivity
Members:
CPU
XPU
MTIA
CUDA
HPU
PrivateUse1
property name
torch.profiler.schedule(*, wait, warmup, active, repeat=0, skip_first=0, skip_first_wait=0)
返回一个可调用对象,可作为分析器的schedule
参数使用。该分析器会跳过前skip_first
个步骤,然后等待wait
个步骤,接着进行warmup
个步骤的热身,随后执行active
个步骤的活跃记录,最后从wait
步骤开始重复这个循环。
通过repeat
参数可指定可选的循环次数,值为零表示循环将持续到分析结束。
skip_first_wait
参数控制是否跳过第一个wait
阶段。
当用户希望在循环之间等待时间超过skip_first
但首次分析不适用时,这个功能很有用。例如,若skip_first
为10且wait
为20,当skip_first_wait
为零时,第一个循环将在热身前等待10+20=30个步骤;若skip_first_wait
非零,则仅等待10个步骤。之后所有循环都会在最后一次活跃记录和热身之间等待20个步骤。
返回类型:Callable
torch.profiler.tensorboard_trace_handler(dir_name, worker_name=None, use_gzip=False)
将跟踪文件输出到 dir_name
目录,该目录可直接作为日志目录传递给 TensorBoard。
在分布式场景中,每个工作节点的 worker_name
应保持唯一,默认会设置为 ‘[hostname]_[pid]’。
Intel 插桩与追踪技术 API
torch.profiler.itt.is_available()
检查 ITT 功能是否可用
torch.profiler.itt.mark(msg)
描述某个时间点发生的瞬时事件。
参数
msg (str)
– 与该事件关联的ASCII消息。
torch.profiler.itt.range_push(msg)
将一段范围压入嵌套范围跨度栈。返回所启动范围的从零开始的深度。
参数
msg (str)
– 与该范围关联的ASCII消息
torch.profiler.itt.range_pop()
从嵌套范围跨度栈中弹出一个范围。返回被结束范围的从零开始的深度。
torch.nn.init
警告:本模块中的所有函数均用于初始化神经网络参数,因此它们都在 torch.no_grad()
模式下运行,且不会被自动求导机制追踪。
torch.nn.init.calculate_gain(nonlinearity, param=None)
返回给定非线性函数的推荐增益值。
具体数值如下:
非线性函数 | 增益值 |
---|---|
Linear / Identity | 111 |
Conv{1,2,3}D | 111 |
Sigmoid | 111 |
Tanh | 53\frac{5}{3}35 |
ReLU | 2\sqrt{2}2 |
Leaky Relu | KaTeX parse error: Expected 'EOF', got '_' at position 34: … \text{negative_̲slope}^2}} |
SELU | 34\frac{3}{4}43 |
警告:为了实现自归一化神经网络,应该使用nonlinearity='linear'
而非nonlinearity='selu'
。这样可以使初始权重具有1/N
的方差,这对在前向传播中形成稳定不动点是必要的。
相比之下,SELU
的默认增益值牺牲了归一化效果,以在矩形层中获得更稳定的梯度流。
参数说明
nonlinearity
- 非线性函数(nn.functional名称)param
- 非线性函数的可选参数
使用示例
>>> gain = nn.init.calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2
torch.nn.init.uniform_(tensor, a=0.0, b=1.0, generator=None)
使用均匀分布中的值填充输入张量。
U(a,b){U}(a, b)U(a,b).
参数
tensor ( Tensor )
– 一个n维的torch.Tensora (float)
– 均匀分布的下界b (float)
– 均匀分布的上界generator (Optional[ Generator ])
– 用于采样的torch生成器(默认值:None)
返回类型:Tensor
示例
>>> w = torch.empty(3, 5)
>>> nn.init.uniform_(w)
torch.nn.init.normal_(tensor, mean=0.0, std=1.0, generator=None)
使用正态分布中的随机值填充输入张量。
N(mean,std2)\mathcal{N}(\text{mean}, \text{std}^2)N(mean,std2)
参数说明
tensor ( Tensor )
– 一个n维的torch.Tensormean (float)
– 正态分布的均值std (float)
– 正态分布的标准差generator (Optional[ Generator ])
– 用于采样的torch生成器对象(默认值为None)
返回值类型:Tensor
使用示例
>>> w = torch.empty(3, 5)
>>> nn.init.normal_(w)
torch.nn.init.constant_(tensor, val)
将输入张量填充为值 val\text{val}val。
参数
tensor ( Tensor )
– 一个 n 维的 torch.Tensorval (float)
– 用于填充张量的值
返回类型 : Tensor
示例
>>> w = torch.empty(3, 5)
>>> nn.init.constant_(w, 0.3)
torch.nn.init.ones_(tensor)
填充输入张量(Tensor)为标量值1。
参数
tensor ( Tensor )
– 一个n维的torch.Tensor
返回类型:Tensor
示例
>>> w = torch.empty(3, 5)
>>> nn.init.ones_(w)
torch.nn.init.zeros_(tensor)
将输入张量填充为标量值0。
参数
tensor ( Tensor )
- 一个n维的torch.Tensor
返回类型: Tensor
示例
>>> w = torch.empty(3, 5)
>>> nn.init.zeros_(w)
torch.nn.init.eye_(tensor)
将二维输入张量填充为单位矩阵。
在Linear层中保留输入的恒等性,尽可能多地保留输入特征。
参数
tensor
– 一个二维的torch.Tensor张量
示例
>>> w = torch.empty(3, 5)
>>> nn.init.eye_(w)
torch.nn.init.dirac_(tensor, groups=1)
用狄拉克δ函数填充{3, 4, 5}维输入张量。
在卷积层中保持输入数据的特性,尽可能多地保留输入通道。当组数大于1时,每组通道都会保持其特性。
参数
tensor
– 一个{3, 4, 5}维的torch.Tensorgroups (int, 可选)
– 卷积层中的组数(默认值:1)
示例
>>> w = torch.empty(3, 16, 5, 5)
>>> nn.init.dirac_(w)
>>> w = torch.empty(3, 24, 5, 5)
>>> nn.init.dirac_(w, 3)
torch.nn.init.xavier_uniform_(tensor, gain=1.0, generator=None)
使用Xavier均匀分布为输入张量填充数值。
该方法在论文《理解深度前馈神经网络训练难点》- Glorot, X. 和 Bengio, Y. (2010)中有详细描述。生成的张量将从U(−a,a)分布中采样,其中:
a = gain × √(6/(fan_in + fan_out))
该初始化方法也被称为Glorot初始化。
参数:
tensor (Tensor)
- 一个n维的torch.Tensorgain (float)
- 可选的缩放因子generator (Optional[Generator])
- 用于采样的torch生成器(默认值:None)
返回类型:Tensor
示例:
>>> w = torch.empty(3, 5)
>>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
注意:fan_in
和 fan_out
的计算基于权重矩阵以转置方式使用的假设(例如 Linear
层中的 x @ w.T
,其中 w.shape = [fan_out, fan_in]
)。
这对正确初始化至关重要。
如果计划使用 x @ w
(其中 w.shape = [fan_in, fan_out]
),请传入转置后的权重矩阵,即 nn.init.xavier_uniform_(w.T, ...)
。
torch.nn.init.xavier_normal_(tensor, gain=1.0, generator=None)
使用Xavier正态分布为输入张量填充数值。
该方法在论文《理解深度前馈神经网络训练的难点》- Glorot, X. 和 Bengio, Y. (2010) 中提出。生成的张量将从U(−a,a)\mathcal{U}(-a, a)U(−a,a)分布中采样数值,其中
KaTeX parse error: Expected 'EOF', got '_' at position 48: …ac{6}{\text{fan_̲in} + \text{fan…
参数
tensor ( Tensor )
– 一个n维的torch.Tensorgain (float)
– 可选的缩放因子generator (Optional[ Generator ])
– 用于采样的torch生成器(默认值:None)
返回类型:Tensor
示例
>>> w = torch.empty(3, 5)
>>> nn.init.xavier_normal_(w)
注意:fan_in
和 fan_out
的计算基于权重矩阵以转置方式使用的假设(例如 Linear
层中的 x @ w.T
,其中 w.shape = [fan_out, fan_in]
)。
这一点对于正确初始化至关重要。
如果你计划使用 x @ w
(其中 w.shape = [fan_in, fan_out]
),请传入转置后的权重矩阵,即 nn.init.xavier_normal_(w.T, ...)
。
torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu', generator=None)
使用Kaiming均匀分布为输入张量填充数值。
该方法出自论文《Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification》——何恺明等人(2015年)。生成的张量数值将采样自U(−bound,bound)U(−bound,bound)U(−bound,bound)区间,其中边界值计算公式为:
KaTeX parse error: Expected 'EOF', got '_' at position 59: …ac{3}{\text{fan_̲mode}}}
该初始化方法也被称为He初始化。
参数说明
tensor ( Tensor )
– 一个n维的torch.Tensora (float)
– 该层后接整流器的负斜率(仅当使用'leaky_relu'
时生效)mode (str)
– 可选'fan_in'
(默认)或'fan_out'
。选择'fan_in'
可保持前向传播中权重的方差量级,选择'fan_out'
则保持反向传播中的量级nonlinearity (str)
– 非线性函数名称(需对应nn.functional中的名称),建议仅使用'relu'
或'leaky_relu'
(默认值)generator (Optional[ Generator ])
– 用于采样的torch生成器对象(默认值为None)
使用示例
>>> w = torch.empty(3, 5)
>>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
注意:fan_in
和 fan_out
的计算基于权重矩阵以转置方式使用的假设(例如在 Linear
层中的 x @ w.T
,其中 w.shape = [fan_out, fan_in]
)。
这一点对于正确的初始化非常重要。
如果你计划使用 x @ w
(其中 w.shape = [fan_in, fan_out]
),请传入转置后的权重矩阵,即 nn.init.kaiming_uniform_(w.T, ...)
。
torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu', generator=None)
使用Kaiming正态分布为输入张量填充数值。
该方法在论文《深入研究整流器:超越ImageNet分类的人类水平性能》- He, K. 等人 (2015) 中有详细描述。生成的张量值将从 N(0,std2)N(0,std^2)N(0,std2) 分布中采样,其中
KaTeX parse error: Expected 'EOF', got '_' at position 48: …\sqrt{\text{fan_̲mode}}}
该方法也被称为He初始化。
参数
tensor ( Tensor )
– 一个n维的torch.Tensora (float)
– 该层之后使用的整流器的负斜率(仅与'leaky_relu'
一起使用)mode (str)
– 可选'fan_in'
(默认)或'fan_out'
。选择'fan_in'
会保持前向传播中权重的方差大小,选择'fan_out'
则会保持反向传播中的大小。nonlinearity (str)
– 非线性函数(nn.functional名称),建议仅与'relu'
或'leaky_relu'
(默认)一起使用。generator (Optional[ Generator ])
– 用于采样的torch生成器(默认:None)
示例
>>> w = torch.empty(3, 5)
>>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
注意:fan_in
和 fan_out
的计算基于权重矩阵以转置方式使用的假设(例如 Linear
层中的 x @ w.T
,其中 w.shape = [fan_out, fan_in]
)。
这对正确初始化至关重要。
如果计划使用 x @ w
(其中 w.shape = [fan_in, fan_out]
),请传入转置后的权重矩阵,即 nn.init.kaiming_normal_(w.T, ...)
。
torch.nn.init.trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0, generator=None)
使用截断正态分布生成的值填充输入张量。
这些值实际上是从正态分布 N(mean,std2)N(mean,std^2)N(mean,std2) 中抽取的,对于超出 [a,b][a,b][a,b] 范围的值会重新抽取,直到其落在边界内。当满足 a≤mean≤ba \leq mean \leq ba≤mean≤b 时,生成随机值的方法效果最佳。
参数
tensor ( Tensor )
– 一个n维的torch.Tensormean (float)
– 正态分布的均值std (float)
– 正态分布的标准差a (float)
– 最小截断值b (float)
– 最大截断值generator (Optional[ Generator ])
– 用于采样的torch生成器(默认值:None)
返回类型:Tensor
示例
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
torch.nn.init.orthogonal_(tensor, gain=1, generator=None)
用(半)正交矩阵填充输入张量。
该方法在《Exact solutions to the nonlinear dynamics of learning in deep linear neural networks》- Saxe, A. 等人 (2013) 的论文中有所描述。输入张量必须至少具有2个维度,对于超过2维的张量,其尾部维度会被展平。
参数
tensor
– 一个n维的 torch.Tensor,其中 $n \geq 2 $gain
– 可选缩放因子generator (Optional[ Generator ])
– 用于采样的 torch Generator(默认值:None)
示例
>>> w = torch.empty(3, 5)
>>> nn.init.orthogonal_(w)
torch.nn.init.sparse_(tensor, sparsity, std=0.01, generator=None)
将二维输入张量填充为稀疏矩阵。
非零元素将从正态分布 N(0,0.01)N(0, 0.01)N(0,0.01) 中抽取,如《Deep learning via Hessian-free optimization》- Martens, J. (2010) 所述。
参数
tensor
– 一个 n 维的 torch.Tensorsparsity
– 每列中要置零的元素比例std
– 用于生成非零值的正态分布的标准差generator (Optional[ Generator ])
– 用于采样的 torch 生成器(默认值:None)
示例
>>> w = torch.empty(3, 5)
>>> nn.init.sparse_(w, sparsity=0.1)
torch.nn.attention
该模块包含用于改变 torch.nn.functional.scaled_dot_product_attention 行为的函数和类
工具集
sdpa_kernel | 用于选择缩放点积注意力计算后端的上下文管理器 |
---|---|
SDPBackend | 枚举类,包含缩放点积注意力计算的不同后端实现 |
子模块
flex_attention | 该模块实现了PyTorch中flex_attention的用户接口API。 |
---|---|
bias | 定义了与scaled_dot_product_attention配合使用的偏置子类 |
experimental |
2025-08-20(三)