当前位置: 首页 > java >正文

PyTorch API 5

文章目录

  • torch.compiler
    • 延伸阅读
  • torch.fft
    • 快速傅里叶变换
    • 辅助函数
  • torch.func
    • 什么是可组合的函数变换?
    • 为什么需要可组合的函数变换?
    • 延伸阅读
  • torch.futures
  • torch.fx
    • 概述
    • 编写转换函数
      • 图结构快速入门
      • 图操作
        • 直接操作计算图
        • 使用 replace_pattern() 进行子图重写
        • 图操作示例
      • 代理/回溯机制
      • 解释器模式
        • 解释器模式示例
    • 调试
      • 简介
      • 变换编写中的常见陷阱
      • 检查模块的正确性
      • 调试生成的代码
        • 使用 `pdb`
        • 打印生成的代码
        • 使用 `GraphModule` 中的 `to_folder` 函数
      • 调试转换过程
      • 可用的调试器
    • 符号追踪的局限性
      • 动态控制流
        • 静态控制流
      • 非`torch`函数
      • 使用 `Tracer` 类自定义追踪功能
        • 叶子模块
      • 杂项说明
    • API 参考
  • torch.fx.experimental
    • torch.fx.experimental.symbolic_shapes
    • torch.fx.experimental.proxy_tensor
  • torch.hub
    • 发布模型
      • 如何实现入口点?
      • 重要通知
    • 从Hub加载模型
      • 运行加载的模型:
      • 下载的模型保存在哪里?
      • 缓存逻辑
      • 已知限制:
  • TorchScript
    • 创建 TorchScript 代码
    • 混合使用追踪与脚本化
    • TorchScript 语言
    • 内置函数与模块
      • PyTorch 函数与模块
      • Python 函数与模块
      • Python 语言参考对比
    • 调试
      • 禁用 JIT 进行调试
      • 代码检查
      • 解读图结构
      • 追踪器
        • 追踪边界情况
        • 自动追踪检查
        • 追踪器警告
    • 常见问题解答
    • 已知问题
    • 附录
      • 迁移至 PyTorch 1.2 递归脚本化 API
        • 模块
        • 函数
        • TorchScript 类
        • 属性
        • 常量
        • 变量
      • 融合后端
      • 参考资料
  • torch.linalg
    • 矩阵属性
    • 矩阵分解
    • 求解器
    • 逆矩阵
    • 矩阵函数
    • 矩阵运算
    • 张量运算
    • 杂项函数
    • 实验性函数
  • torch.monitor
    • API 参考
  • torch.signal 模块
    • torch.signal.windows 窗口函数
  • torch.special
    • 函数
  • torch.overrides
    • 函数
  • torch.package
    • 教程
      • 打包你的第一个模型
    • 如何实现...
      • 查看包内包含哪些内容?
        • 将包视为ZIP归档文件处理
        • 使用 `file_structure()` API
      • 查看某个模块为何被列为依赖项?
      • 如何在打包时包含任意资源并后续访问?
      • 自定义类的打包方式
      • 如何在源码中检测当前是否运行在包环境中?
      • 如何将代码补丁打入包中?
      • 如何从打包代码中访问包内容?
      • 区分打包代码与非打包代码
      • 如何重新导出已导入的对象?
      • 如何打包 TorchScript 模块?
    • 说明
      • `torch.package` 格式概述
        • 框架文件
        • 用户文件
      • `torch.package` 如何查找代码依赖项
        • 分析对象的依赖关系
        • 分析模块依赖关系
      • 依赖管理
        • `intern`
        • `extern`
        • `mock`
        • 代码重构
        • 模式
      • `torch.package` 的注意事项
        • 避免在模块中使用全局状态
        • 类型在包与加载环境之间不共享
      • `torch.package` 如何实现包之间的隔离
        • 名称修饰(Mangling)
    • API 参考
  • torch.profiler
    • 概述
    • API 参考
    • Intel 插桩与追踪技术 API
  • torch.nn.init
  • torch.nn.attention
    • 工具集
    • 子模块


torch.compiler

torch.compiler 是一个命名空间,通过它向用户开放了一些内部编译器方法。该命名空间中的主要功能和特性是 torch.compile

torch.compile 是 PyTorch 2.x 引入的一个函数,旨在解决 PyTorch 中精确图捕获的问题,最终帮助软件工程师加速运行他们的 PyTorch 程序。torch.compile 使用 Python 编写,标志着 PyTorch 从 C++ 向 Python 的过渡。

torch.compile 利用了以下底层技术:

  • TorchDynamo (torch._dynamo) 是一个内部 API,它使用 CPython 的 Frame Evaluation API 功能来安全捕获 PyTorch 计算图。通过 torch.compiler 命名空间向 PyTorch 用户开放可用方法。
  • TorchInductortorch.compile 默认的深度学习编译器,为多种加速器和后端生成快速代码。需要通过后端编译器才能实现 torch.compile 的加速效果。对于 NVIDIA、AMD 和 Intel GPU,它使用 OpenAI Triton 作为关键构建块。
  • AOT Autograd 不仅能捕获用户级代码,还能捕获反向传播,实现"提前"捕获反向传递。这使得 TorchInductor 能够同时加速前向和反向传递。

注意:在本文档中,术语 torch.compile、TorchDynamo 和 torch.compiler 有时会互换使用。

如上所述,要通过 TorchDynamo 运行更快的工作流,torch.compile 需要一个后端将捕获的计算图转换为快速机器码。不同的后端会带来不同的优化效果。默认后端是 TorchInductor(也称为 inductor)。TorchDynamo 还支持由合作伙伴开发的一系列后端,可以通过运行 torch.compiler.list_backends() 查看,每个后端都有其可选依赖项。

一些最常用的后端包括:

训练和推理后端

后端描述
torch.compile(m, backend="inductor")使用 TorchInductor 后端。了解更多
torch.compile(m, backend="cudagraphs")使用 AOT Autograd 的 CUDA 图。了解更多
torch.compile(m, backend="ipex")在 CPU 上使用 IPEX。了解更多
torch.compile(m, backend="onnxrt")使用 ONNX Runtime 在 CPU/GPU 上进行训练。了解更多

仅推理后端

后端描述
torch.compile(m, backend="tensorrt")使用 Torch-TensorRT 进行推理优化。需要在调用脚本中 import torch_tensorrt 来注册后端。了解更多
torch.compile(m, backend="ipex")在 CPU 上使用 IPEX 进行推理。了解更多
torch.compile(m, backend="tvm")使用 Apache TVM 进行推理优化。了解更多
torch.compile(m, backend="openvino")使用 OpenVINO 进行推理优化。了解更多

延伸阅读

PyTorch 用户入门指南

  • 快速入门
  • torch.compiler API 参考
  • torch.compiler.config 配置
  • TorchDynamo 细粒度追踪 API
  • AOTInductor: Torch.Export 模型的预编译方案
  • TorchInductor GPU 性能分析
  • torch.compile 性能剖析指南
  • 常见问题解答
  • torch.compile 故障排查
  • PyTorch 2.0 性能看板

PyTorch 开发者深度解析

  • Dynamo 架构概览
  • Dynamo 技术深潜
  • 动态形状支持
  • PyTorch 2.0 NNModule 支持
  • 后端开发最佳实践
  • CUDA 图树优化
  • 伪张量机制

PyTorch 后端供应商指南

  • 自定义后端开发
  • ATen IR 图转换开发
  • 中间表示层详解


torch.fft

离散傅里叶变换及相关函数。


快速傅里叶变换

fft计算input的一维离散傅里叶变换
ifft计算input的一维离散傅里叶逆变换
fft2计算input的二维离散傅里叶变换
ifft2计算input的二维离散傅里叶逆变换
fftn计算input的N维离散傅里叶变换
ifftn计算input的N维离散傅里叶逆变换
rfft计算实数input的一维傅里叶变换
irfft计算rfft()的逆变换
rfft2计算实数input的二维离散傅里叶变换
irfft2计算rfft2()的逆变换
rfftn计算实数input的N维离散傅里叶变换
irfftn计算rfftn()的逆变换
hfft计算Hermitian对称input信号的一维离散傅里叶变换
ihfft计算hfft()的逆变换
hfft2计算Hermitian对称input信号的二维离散傅里叶变换
ihfft2计算实数input的二维离散傅里叶逆变换
hfftn计算Hermitian对称input信号的N维离散傅里叶变换
ihfftn计算实数input的N维离散傅里叶逆变换

辅助函数

fftfreq计算大小为 n 的信号的离散傅里叶变换采样频率。
rfftfreq计算大小为 n 的信号在使用 rfft() 时的采样频率。
fftshift对由 fftn() 提供的 n 维 FFT 数据进行重新排序,使负频率项优先。
ifftshiftfftshift() 的逆操作。


torch.func

torch.func(前身为"functorch")是为PyTorch提供的JAX风格可组合函数变换工具。


注意:该库目前处于测试阶段。
这意味着这些功能基本可用(除非另有说明),且我们(PyTorch团队)将持续推进该库的发展。但API可能会根据用户反馈进行调整,且尚未完全覆盖所有PyTorch操作。

如果您对API有改进建议,或希望支持特定使用场景,请提交GitHub issue或直接联系我们。我们非常期待了解您如何使用这个库。


什么是可组合的函数变换?

  • 函数变换是一种高阶函数,它接受一个数值函数作为输入,并返回一个新函数来计算不同的量。
  • torch.func 提供了自动微分变换(例如 grad(f) 返回计算 f 梯度的函数)、向量化/批处理变换(例如 vmap(f) 返回对输入批次执行 f 的函数)等多种变换。
  • 这些函数变换可以任意组合使用。例如,组合 vmap(grad(f)) 可以计算单样本梯度(per-sample-gradients),这是当前标准 PyTorch 无法高效计算的量。

为什么需要可组合的函数变换?

目前在 PyTorch 中实现以下用例较为棘手:

  • 计算逐样本梯度(或其他逐样本量)
  • 在单台机器上运行模型集成
  • 在 MAML 内循环中高效批处理任务
  • 高效计算雅可比矩阵和海森矩阵
  • 高效计算批量雅可比矩阵和海森矩阵

通过组合使用 vmap()grad()vjp() 变换,我们无需为每个用例单独设计子系统即可实现上述功能。这种可组合函数变换的理念源自 JAX 框架。


延伸阅读

  • torch.func 快速指南
    • 什么是 torch.func?
    • 为什么需要可组合函数变换?
    • 有哪些变换方法?
  • torch.func API 参考
    • 函数变换
    • torch.nn.Module 工具集
    • 调试工具
  • 使用限制
    • 通用限制
    • torch.autograd API
    • vmap 限制
    • 随机性控制
  • 从 functorch 迁移到 torch.func
    • 函数变换
    • 神经网络模块工具
    • functorch.compile


torch.futures

该包提供了一种 Future 类型,用于封装异步执行过程,并提供一组实用函数来简化对 Future 对象的操作。目前,Future 类型主要被 分布式RPC框架 使用。


class torch.futures.Future(*, devices=None) 

Wrapper around a torch._C.Future which encapsulates an asynchronous
execution of a callable, e.g. rpc_async(). It also exposes a set of APIs to add callback functions and set results.


Warning: GPU support is a beta feature, subject to changes.


add_done_callback(callback)

将给定的回调函数附加到此Future上,该回调函数将在Future完成时运行。可以向同一个Future添加多个回调,但无法保证它们的执行顺序。回调函数必须接受一个参数,即对此Future的引用。回调函数可以使用value()方法获取值。请注意,如果此Future已经完成,给定的回调将立即内联执行。

我们建议使用then()方法,因为它提供了一种在回调完成后进行同步的方式。如果回调不返回任何内容,add_done_callback可能更高效。但then()add_done_callback在底层使用相同的回调注册API。

对于GPU张量,此方法的行为与then()相同。

参数

  • callback (Future) – 一个可调用对象,接受一个参数,即对此Future的引用。

注意:请注意,如果回调函数抛出异常,无论是由于原始future以异常完成并调用fut.wait(),还是由于回调中的其他代码,都必须仔细处理错误。例如,如果此回调随后完成了其他future,这些future不会被标记为以错误完成,用户需要独立处理这些future的完成/等待。


示例

>>> def callback(fut):
...     print("This will run after the future has finished.")
...     print(fut.wait())
>>> fut = torch.futures.Future()
>>> fut.add_done_callback(callback)
>>> fut.set_result(5)
This will run after the future has finished.
5

done()

如果该Future已完成则返回True。当Future包含结果或异常时即视为完成。

如果值包含位于GPU上的张量,即使填充这些张量的异步内核尚未在设备上完成运行,Future.done()仍会返回True,因为在此阶段结果已可被使用(前提是执行适当的同步操作,参见wait())。

返回类型:bool


set_exception(result)

为这个 Future 设置一个异常,这将标记该 Future 以错误状态完成,并触发所有已附加的回调。请注意,当对此 Future 调用 wait()/value() 时,此处设置的异常

将被内联抛出。

参数

  • result ([BaseException](https://docs.python.org/3/library/exceptions.html#BaseException "(in Python v3.13)")) – 该 Future 的异常对象。

示例

>>> fut = torch.futures.Future()
>>> fut.set_exception(ValueError("foo"))
>>> fut.wait()
Traceback (most recent call last):
...
ValueError: foo

set_result(result)

为这个Future设置结果,这将标记该Future为已完成状态并触发所有关联的回调。需要注意的是,一个Future不能被标记为已完成两次。

如果结果包含位于GPU上的张量,即使填充这些张量的异步内核尚未在设备上完成运行,只要调用此方法时这些内核所入队的流被设置为当前流,仍可调用此方法。简而言之,在启动这些内核后立即调用此方法是安全的,无需额外同步,前提是期间不切换流。此方法会在所有相关当前流上记录事件,并利用它们确保此Future的所有消费者都能得到正确调度。

参数

  • result ( object ) - 该Future的结果对象。

示例:


>>> import threading
>>> import time
>>> def slow_set_future(fut, value):
...     time.sleep(0.5)
...     fut.set_result(value)
>>> fut = torch.futures.Future()
>>> t = threading.Thread(
...     target=slow_set_future, 
...     args=(fut, torch.ones(2) * 3)
... )
>>> t.start()
>>> print(fut.wait())
tensor([3., 3.])
>>> t.join()

then(callback)

将给定的回调函数附加到此Future上,该回调函数将在Future完成时运行。可以向同一个Future添加多个回调,但无法保证它们的执行顺序(如需确保特定顺序,请考虑链式调用:fut.then(cb1).then(cb2))。回调函数必须接受一个参数,即对此Future的引用。回调函数可通过value()方法获取值。请注意,如果此Future已完成,给定的回调将立即内联执行。

如果Future的值包含位于GPU上的张量,回调可能在填充这些张量的异步内核尚未在设备上完成执行时就被调用。不过,回调将通过设置为当前的一些专用流(从全局池中获取)被调用,这些流将与那些内核同步。因此,回调对这些张量执行的任何操作都将在内核完成后调度到设备上。换句话说,只要回调不切换流,它就可以安全地操作结果而无需额外同步。这与wait()的非阻塞行为类似。

类似地,如果回调返回的值包含位于GPU上的张量,即使生成这些张量的内核仍在设备上运行,回调也可以这样做,前提是回调在执行期间没有切换流。如果想要切换流,必须注意与原始流重新同步,即回调被调用时当前的流。

参数

  • callback (Callable) – 一个以该Future为唯一参数的可调用对象。

返回

一个新的Future对象,它持有callback的返回值,并将在给定callback完成时标记为已完成。

返回类型

Future[S]

注意:请注意,如果回调函数抛出异常,无论是通过原始future以异常完成并调用fut.wait(),还是通过回调中的其他代码,then返回的future将适当地标记为遇到错误。但是,如果此回调随后完成其他future,这些future不会标记为以错误完成,用户需负责独立处理这些future的完成/等待。


示例

>>> def callback(fut):
...     print(f"RPC return value is {fut.wait()}.")
>>> fut = torch.futures.Future()
>>> # The inserted callback will print the return value when
>>> # receiving the response from "worker1"
>>> cb_fut = fut.then(callback)
>>> chain_cb_fut = cb_fut.then(
...     lambda x : print(f"Chained cb done. {x.wait()}")
... )
>>> fut.set_result(5)
RPC return value is 5、Chained cb done. None

value()

获取已完成的Future对象的值。

此方法仅应在调用wait()完成后,或在传递给then()的回调函数内部使用。其他情况下,该Future可能尚未持有值,调用value()可能会失败。

如果值包含位于GPU上的张量,此方法将不会执行任何额外的同步操作。此类同步应事先通过调用wait()单独完成(回调函数内部除外,因为then()已自动处理此情况)。

返回值
Future持有的值。如果创建该值的函数(回调或RPC)抛出错误,此value()方法同样会抛出错误。

返回类型:T


wait()

等待直到该 Future 的值准备就绪。

如果值包含位于 GPU 上的张量,则会与设备上异步填充这些张量的内核执行额外的同步操作。此类同步是非阻塞的,这意味着 wait() 会在当前流中插入必要的指令,以确保后续在这些流上排队的操作能正确安排在异步内核之后执行。但一旦完成指令插入,即使这些内核仍在运行,wait() 也会立即返回。只要不切换流,在访问和使用这些值时无需进一步同步。

返回值:此 Future 持有的值。如果创建该值的函数(回调或 RPC)抛出错误,此 wait 方法同样会抛出错误。

返回类型:T


torch.futures.collect_all(futures)

将提供的 Future 对象收集到一个统一的组合 Future 中,该组合 Future 会在所有子 Future 完成时完成。

参数

  • futures (list) – 一个包含 Future 对象的列表。

返回

返回一个 Future 对象,该对象关联到传入的 Future 列表。

返回类型

Future[list [torch.jit.Future]]


示例

>>> fut0 = torch.futures.Future()
>>> fut1 = torch.futures.Future()
>>> fut = torch.futures.collect_all([fut0, fut1])
>>> fut0.set_result(0)
>>> fut1.set_result(1)
>>> fut_list = fut.wait()
>>> print(f"fut0 result = {fut_list[0].wait()}")
fut0 result = 0
>>> print(f"fut1 result = {fut_list[1].wait()}")
fut1 result = 1

torch.futures.wait_all(futures)

等待所有提供的 futures 完成,并返回已完成值的列表。如果任一 future 遇到错误,该方法将提前退出并报告错误,而不会等待其他 futures 完成。

参数

  • futures (list) – 一个 Future 对象列表。

返回值:已完成 Future 结果的列表。如果对任何 Future 调用 wait 时抛出错误,该方法也会抛出错误。

返回类型:list



torch.fx


概述

FX 是一个供开发者使用的工具包,用于转换 nn.Module 实例。FX 包含三个核心组件:符号追踪器中间表示Python 代码生成。以下是这些组件的实际应用演示:

import torch# Simple module for demonstration
class MyModule(torch.nn.Module):def __init__(self) -None:super().__init__()self.param = torch.nn.Parameter(torch.rand(3, 4))self.linear = torch.nn.Linear(4, 5)def forward(self, x):return self.linear(x + self.param).clamp(min=0.0, max=1.0)module = MyModule()from torch.fx import symbolic_trace# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():%x : [num_users=1] = placeholder[target=x]%param : [num_users=1] = get_attr[target=param]%add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})%linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})%clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})return clamp
"""# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):param = self.paramadd = x + param;  x = param = Nonelinear = self.linear(add);  add = Noneclamp = linear.clamp(min = 0.0, max = 1.0);  linear = Nonereturn clamp
"""

符号追踪器(symbolic tracer)对Python代码执行"符号执行"。它通过代码传递称为Proxy的虚拟值,并记录对这些Proxy的操作。有关符号追踪的更多信息,请参阅symbolic_trace()Tracer文档。

中间表示(intermediate representation)是符号追踪过程中记录操作的容器。它由一组节点组成,这些节点表示函数输入、调用点(指向函数、方法或torch.nn.Module实例)以及返回值。有关IR的更多信息,请参阅Graph文档。IR是应用转换的基础格式。

Python代码生成功能使FX成为Python到Python(或Module到Module)的转换工具包。对于每个Graph IR,我们都可以生成符合Graph语义的有效Python代码。这个功能被封装在GraphModule中,它是一个torch.nn.Module实例,包含一个Graph以及从Graph生成的forward方法。

这些组件(符号追踪→中间表示→转换→Python代码生成)共同构成了FX的Python到Python转换流程。此外,这些组件也可以单独使用。例如,符号追踪可以单独用于捕获代码形式进行分析(而非转换)目的。代码生成可以用于通过编程方式生成模型,例如从配置文件生成。FX有许多用途!

在示例库中可以找到几个转换示例。


编写转换函数

什么是FX转换?本质上,它是一个形如下列的函数。


import torch
import torch.fxdef transform(m: nn.Module,        tracer_class : type = torch.fx.Tracer) -torch.nn.Module:# Step 1: Acquire a Graph representing the code in `m`# NOTE: torch.fx.symbolic_trace is a wrapper around a call to     # fx.Tracer.trace and constructing a GraphModule. We'll# split that out in our transform to allow the caller to     # customize tracing behavior.graph : torch.fx.Graph = tracer_class().trace(m)# Step 2: Modify this Graph or create a new onegraph = ...# Step 3: Construct a Module to returnreturn torch.fx.GraphModule(m, graph)

您的转换器将接收一个 torch.nn.Module,从中获取 Graph,进行一些修改后返回一个新的 torch.nn.Module。您应该将 FX 转换器返回的 torch.nn.Module 视为与常规 torch.nn.Module 完全相同——可以将其传递给另一个 FX 转换器、传递给 TorchScript 或直接运行它。确保 FX 转换器的输入和输出均为 torch.nn.Module 将有助于实现组合性。

注意:也可以直接修改现有的 GraphModule 而不创建新实例,例如:

import torch
import torch.fxdef transform(m : nn.Module) -nn.Module:gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)# Modify gm.graph# <...># Recompile the forward() method of `gm` from its Graphgm.recompile()return gm

请注意,你必须调用 GraphModule.recompile() 方法,使生成的 forward() 方法与修改后的 Graph 保持同步。

假设你已经传入了一个经过追踪转换为 Graphtorch.nn.Module,现在主要有两种方法来构建新的 Graph


图结构快速入门

关于图的语义完整说明可以参考 Graph 文档,这里我们主要介绍基础概念。Graph 是一种数据结构,用于表示 GraphModule 上的方法。其核心需要描述以下信息:

  • 方法的输入参数是什么?
  • 方法内部运行了哪些操作?
  • 方法的输出(即返回值)是什么?

这三个概念都通过 Node 实例来表示。下面通过一个简单示例来说明:

import torch
import torch.fxclass MyModule(torch.nn.Module):def __init__(self):super().__init__()self.param = torch.nn.Parameter(torch.rand(3, 4))self.linear = torch.nn.Linear(4, 5)def forward(self, x):return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)m = MyModule()
gm = torch.fx.symbolic_trace(m)gm.graph.print_tabular()

这里我们定义一个演示用的模块 MyModule,实例化后进行符号追踪,然后调用 Graph.print_tabular() 方法打印该 Graph 的节点表格:

操作码名称目标参数关键字参数
placeholderxx(){}
get_attrlinear_weightlinear.weight(){}
call_functionadd_1<built-in function add(x, linear_weight){}
call_modulelinear_1linear(add_1,){}
call_methodrelu_1relu(linear_1,){}
call_functionsum_1<built-in method sum …(relu_1,){‘dim’: -1}
call_functiontopk_1<built-in method topk …(sum_1, 3){}
outputoutputoutput(topk_1,){}

通过这些信息,我们可以回答之前提出的问题:

  • 方法的输入是什么?
    在FX中,方法输入通过特殊的 placeholder 节点指定。本例中有一个目标为 xplaceholder 节点,表示存在一个名为x的(非self)参数。
  • 方法内部有哪些操作?
    get_attrcall_functioncall_modulecall_method 节点表示方法中的操作。这些节点的完整语义说明可参考 Node 文档。
  • 方法的返回值是什么?
    Graph 中,返回值由特殊的 output 节点指定。

现在我们已经了解FX中代码表示的基本原理,接下来可以探索如何编辑 Graph


图操作


直接操作计算图

构建新Graph的一种方法是直接操作原有计算图。为此,我们可以简单地获取通过符号追踪得到的Graph并进行修改。例如,假设我们需要将所有torch.add()调用替换为torch.mul()调用。


import torch
import torch.fx# Sample module
class M(torch.nn.Module):def forward(self, x, y):return torch.add(x, y)def transform(m: torch.nn.Module,        tracer_class : type = fx.Tracer) -torch.nn.Module:graph : fx.Graph = tracer_class().trace(m)# FX represents its Graph as an ordered list of     # nodes, so we can iterate through them.for node in graph.nodes:# Checks if we're calling a function (i.e:# torch.add)if node.op == 'call_function':# The target attribute is the function# that call_function calls.if node.target == torch.add:node.target = torch.mulgraph.lint() # Does some checks to make sure the                  # Graph is well-formed.return fx.GraphModule(m, graph)

我们还可以进行更复杂的 Graph 重写操作,例如删除或追加节点。为了辅助这些转换,FX 提供了一些用于操作计算图的实用函数,这些函数可以在 Graph 文档中找到。

下面展示了一个使用这些 API 追加 torch.relu() 调用的示例。


# Specifies the insertion point. Any nodes added to the # Graph within this scope will be inserted after `node` with traced.graph.inserting_after(node):# Insert a new `call_function` node calling `torch.relu`new_node = traced.graph.call_function(torch.relu, args=(node,))# We want all places that used the value of `node` to     # now use that value after the `relu` call we've added.# We use the `replace_all_uses_with` API to do this.node.replace_all_uses_with(new_node)

对于仅包含替换操作的简单转换,您也可以使用子图重写器。


使用 replace_pattern() 进行子图重写

FX 在直接图操作的基础上提供了更高层次的自动化能力。replace_pattern() API 本质上是一个用于编辑 Graph 的"查找/替换"工具。它允许你指定一个 pattern(模式)和 replacement(替换)函数,然后会追踪这些函数,在图中找到与 pattern 图匹配的操作组实例,并用 replacement 图的副本替换这些实例。这可以极大地自动化繁琐的图操作代码,随着转换逻辑变得复杂,手动操作会变得难以维护。


图操作示例
  • 替换单个操作符
  • 卷积/批量归一化融合
  • replace_pattern:基础用法
  • 量化
  • 逆变换

代理/回溯机制

另一种操作 Graph 的方式是复用符号追踪中使用的 Proxy 机制。例如,假设我们需要编写一个将 PyTorch 函数分解为更小操作的转换器:将每个 F.relu(x) 调用转换为 (x > 0) * x。传统做法可能是通过图重写来插入比较和乘法操作,然后清理原始的 F.relu。但借助 Proxy 对象,我们可以自动将操作记录到 Graph 中来实现这一过程。

具体实现时,只需将需要插入的操作写成常规 PyTorch 代码,并用 Proxy 对象作为参数调用该代码。这些 Proxy 对象会捕获对其执行的操作,并将其追加到 Graph 中。


# Note that this decomposition rule can be read as regular Python
def relu_decomposition(x):return (x 0) * xdecomposition_rules = {}
decomposition_rules[F.relu] = relu_decompositiondef decompose(model: torch.nn.Module,        tracer_class : type = fx.Tracer) -torch.nn.Module:"""Decompose `model` into smaller constituent operations.Currently,this only supports decomposing ReLU into itsmathematical definition: (x 0) * x"""graph : fx.Graph = tracer_class().trace(model)new_graph = fx.Graph()env = {}tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)for node in graph.nodes:if node.op == 'call_function' and node.target in decomposition_rules:# By wrapping the arguments with proxies,      # we can dispatch to the appropriate# decomposition rule and implicitly add it# to the Graph by symbolically tracing it.proxy_args = [fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args]output_proxy = decomposition_rules[node.target](proxy_args)# Operations on `Proxy` always yield new `Proxy`s, and the             # return value of our decomposition rule is no exception.# We need to extract the underlying `Node` from the `Proxy`# to use it in subsequent iterations of this transform.new_node = output_proxy.nodeenv[node.name] = new_nodeelse:# Default case: we don't have a decomposition rule for this             # node, so just copy the node over into the new graph.new_node = new_graph.node_copy(node, lambda x: env[x.name])env[node.name] = new_nodereturn fx.GraphModule(model, new_graph)

除了避免显式的图操作外,使用Proxy还允许您将重写规则指定为原生Python代码。对于需要大量重写规则的转换(如vmap或grad),这通常可以提高规则的可读性和可维护性。

需要注意的是,在调用Proxy时,我们还传递了一个指向底层变量图的追踪器。这样做是为了防止当图中的操作是n元操作时(例如add是二元运算符),调用Proxy不会创建多个图追踪器实例,否则可能导致意外的运行时错误。特别是在底层操作不能安全地假设为一元操作时,我们推荐使用这种Proxy方法。

一个使用Proxy进行Graph操作的实际示例可以在这里找到。


解释器模式

在FX中,一个实用的代码组织模式是遍历Graph中的所有Node并执行它们。这种模式可用于多种场景,包括:

  • 运行时分析流经计算图的值
  • 通过Proxy重新追踪来实现代码转换

例如,假设我们想运行一个GraphModule,并在运行时记录节点上torch.Tensor的形状和数据类型属性。实现代码可能如下:

import torch
import torch.fx
from torch.fx.node import Nodefrom typing import Dictclass ShapeProp:"""Shape propagation. This class takes a `GraphModule`.Then, its `propagate` method executes the `GraphModule`node-by-node with the given arguments. As each operationexecutes, the ShapeProp class stores away the shape and     element type for the output values of each operation on     the `shape` and `dtype` attributes of the operation's`Node`."""def __init__(self, mod):self.mod = modself.graph = mod.graphself.modules = dict(self.mod.named_modules())def propagate(self, args):args_iter = iter(args)env : Dict[str, Node] = {}def load_arg(a):return torch.fx.graph.map_arg(a, lambda n: env[n.name])def fetch_attr(target : str):target_atoms = target.split('.')attr_itr = self.modfor i, atom in enumerate(target_atoms):if not hasattr(attr_itr, atom):raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")attr_itr = getattr(attr_itr, atom)return attr_itrfor node in self.graph.nodes:if node.op == 'placeholder':result = next(args_iter)elif node.op == 'get_attr':result = fetch_attr(node.target)elif node.op == 'call_function':result = node.target(load_arg(node.args), *load_arg(node.kwargs))elif node.op == 'call_method':self_obj, args = load_arg(node.args)kwargs = load_arg(node.kwargs)result = getattr(self_obj, node.target)(args, *kwargs)elif node.op == 'call_module':result = self.modules[node.target](load_arg(node.args), *load_arg(node.kwargs))# This is the only code specific to shape propagation.# you can delete this `if` branch and this becomes# a generic GraphModule interpreter.if isinstance(result, torch.Tensor):node.shape = result.shapenode.dtype = result.dtypeenv[node.name] = resultreturn load_arg(self.graph.result)

如你所见,为FX实现一个完整的解释器并不复杂,但却非常实用。为了简化这一模式的使用,我们提供了Interpreter类,它封装了上述逻辑,允许通过方法重写来覆盖解释器执行的某些方面。

除了执行操作外,我们还可以通过向解释器传递Proxy值来生成新的计算图。

类似地,我们提供了Transformer类来封装这种模式。Transformer的行为与Interpreter类似,但不同于调用run方法从模块获取具体输出值,你需要调用Transformer.transform()方法来返回一个新的GraphModule,该模块会应用你通过重写方法设置的任何转换规则。


解释器模式示例
  • 形状传播
  • 性能分析器

调试


简介

在编写转换代码的过程中,我们的代码往往不会一开始就完全正确。这时就需要进行调试。关键在于采用逆向思维:首先检查调用生成模块的结果,验证其正确性;接着审查并调试生成的代码;最后追溯导致生成代码的转换过程并进行调试。

如果您不熟悉调试工具,请参阅辅助章节可用调试工具。


变换编写中的常见陷阱

  • set迭代顺序的不确定性。在Python中,set数据类型是无序的。例如,使用set来存储Node等对象集合可能导致意外的非确定性行为。比如当迭代一组Node并将其插入Graph时,由于set数据类型是无序的,输出程序中操作的顺序将是非确定性的,且每次程序调用都可能变化。

推荐的替代方案是使用dict数据类型。自Python 3.7起(以及cPython 3.6起),dict保持了插入顺序。通过将需要去重的值存储在dict的键中,可以等效地实现set的功能。


检查模块的正确性

由于大多数深度学习模块的输出都是浮点型 torch.Tensor 实例,因此检查两个 torch.nn.Module 的结果是否相等并不像简单的相等性检查那样直接。为了说明这一点,我们来看一个示例:

import torch
import torch.fx
import torchvision.models as modelsdef transform(m : torch.nn.Module) -torch.nn.Module:gm = torch.fx.symbolic_trace(m)# Imagine we're doing some transforms here# <...>gm.recompile()return gmresnet18 = models.resnet18()
transformed_resnet18 = transform(resnet18)input_image = torch.randn(5, 3, 224, 224)assert resnet18(input_image) == transformed_resnet18(input_image)
"""
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
"""

在这里,我们尝试使用==相等运算符来检查两个深度学习模型的值是否相等。然而,这种做法存在两个问题:首先,该运算符返回的是张量而非布尔值;其次,浮点数值的比较应考虑误差范围(或epsilon),以解决浮点运算不可交换性的问题(详见此处)。

我们可以改用torch.allclose()函数,它会基于相对和绝对容差阈值进行近似比较:

assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))

这是我们工具箱中的第一个工具,用于检查转换后的模块与参考实现相比是否按预期运行。


调试生成的代码

由于 FX 在 GraphModule 上生成 forward() 函数,使用传统的调试技术(如 print 语句或 pdb)会不太直观。幸运的是,我们有多种方法可以用来调试生成的代码。


使用 pdb

通过调用 pdb 可以进入正在运行的程序进行调试。虽然表示 Graph 的代码不在任何源文件中,但当执行前向传播时,我们仍然可以手动使用 pdb 进入该代码进行调试。


import torch
import torch.fx
import torchvision.models as modelsdef my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -torch.nn.Module:graph = tracer_class().trace(inp)# Transformation logic here# <...># Return new Modulereturn fx.GraphModule(inp, graph)my_module = models.resnet18()
my_module_transformed = my_pass(my_module)input_value = torch.randn(5, 3, 224, 224)# When this line is executed at runtime, we will be dropped into an # interactive `pdb` prompt. We can use the `step` or `s` command to # step into the execution of the next line
import pdb; pdb.set_trace()my_module_transformed(input_value)

打印生成的代码

如果需要多次运行相同的代码,使用pdb逐步调试到目标代码可能会有些繁琐。这种情况下,一个简单的方法是将生成的forward传递代码直接复制粘贴到你的代码中,然后在那里进行检查。


# Assume that `traced` is a GraphModule that has undergone some
# number of transforms# Copy this code for later
print(traced)
# Print the code generated from symbolic tracing. This outputs:
"""
def forward(self, y):x = self.xadd_1 = x + y;  x = y = Nonereturn add_1
"""# Subclass the original Module
class SubclassM(M):def __init__(self):super().__init__()# Paste the generated `forward` function (the one we printed and     # copied above) heredef forward(self, y):x = self.xadd_1 = x + y;  x = y = Nonereturn add_1# Create an instance of the original, untraced Module. Then, create an # instance of the Module with the copied `forward` function. We can # now compare the output of both the original and the traced version.
pre_trace = M()
post_trace = SubclassM()

使用 GraphModule 中的 to_folder 函数

GraphModule.to_folder()GraphModule 中的一个方法,它允许你将生成的 FX 代码导出到一个文件夹。虽然像打印生成的代码中那样直接复制前向传播代码通常已经足够,但使用 to_folder 可以更方便地检查模块和参数。


m = symbolic_trace(M())
m.to_folder("foo", "Bar")
from foo import Bar
y = Bar()

运行上述示例后,我们可以查看foo/module.py中的代码,并根据需要进行修改(例如添加print语句或使用pdb)来调试生成的代码。


调试转换过程

既然我们已经确认是转换过程生成了错误代码,现在就该调试转换本身了。首先,我们会查阅文档中的符号追踪限制部分。在确认追踪功能按预期工作后,我们的目标就转变为找出GraphModule转换过程中出现的问题。编写转换部分可能有快速解决方案,如果没有的话,我们还可以通过多种方式来检查追踪模块:

# Sample Module
class M(torch.nn.Module):def forward(self, x, y):return x + y# Create an instance of `M`
m = M()# Symbolically trace an instance of `M` (returns a GraphModule). In
# this example, we'll only be discussing how to inspect a # GraphModule, so we aren't showing any sample transforms for the # sake of brevity.
traced = symbolic_trace(m)# Print the code produced by tracing the module.
print(traced)
# The generated `forward` function is:
"""
def forward(self, x, y):add = x + y;  x = y = Nonereturn add
"""# Print the internal Graph.
print(traced.graph)
# This print-out returns:
"""
graph():%x : [num_users=1] = placeholder[target=x]%y : [num_users=1] = placeholder[target=y]%add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})return add
"""# Print a tabular representation of the internal Graph.
traced.graph.print_tabular()
# This gives us:
"""
opcode         name    target                   args    kwargs
-------------  ------  -----------------------  ------  --------
placeholder    x       x                        ()      {}
placeholder    y       y                        ()      {}
call_function  add     <built-in function add (x, y)  {}
output         output  output                   (add,)  {}
"""

通过使用上述工具函数,我们可以对比应用转换前后的追踪模块。有时,简单的视觉对比就足以定位错误。如果问题仍不明确,下一步可以尝试使用 pdb 这类调试器。

以上述示例为基础,请看以下代码:

# Sample user-defined function
def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -torch.nn.Module:# Get the Graph from our traced Moduleg = tracer_class().trace(module)"""Transformations on `g` go here"""return fx.GraphModule(module, g)# Transform the Graph
transformed = transform_graph(traced)# Print the new code after our transforms. Check to see if it was
# what we expected
print(transformed)

以上述示例为例,假设调用print(traced)时发现转换过程中存在错误。我们需要通过调试器定位问题根源。启动pdb调试会话后,可以在transform_graph(traced)处设置断点,然后按s键"步入"该函数调用,实时观察转换过程。

另一个有效方法是修改print_tabular方法,使其输出图中节点的不同属性(例如查看节点的input_nodesusers关系)。


可用的调试器

最常用的Python调试器是pdb。你可以通过在命令行输入python -m pdb FILENAME.py来以"调试模式"启动程序,其中FILENAME是你要调试的文件名。之后,你可以使用pdb的调试器命令逐步执行正在运行的程序。通常的做法是在启动pdb时设置一个断点(b LINE-NUMBER),然后调用c让程序运行到该断点处。这样可以避免你不得不使用sn逐行执行代码才能到达想要检查的部分。或者,你也可以在想中断的代码行前写入import pdb; pdb.set_trace()。如果添加了pdb.set_trace(),当你运行程序时它会自动进入调试模式(换句话说,你只需在命令行输入python FILENAME.py而不用输入python -m pdb FILENAME.py)。一旦以调试模式运行文件,你就可以使用特定命令逐步执行代码并检查程序的内部状态。网上有很多关于pdb的优秀教程,包括RealPython的《Python Debugging With Pdb》。

像PyCharm或VSCode这样的IDE通常内置了调试器。在你的IDE中,你可以选择:a)通过调出IDE中的终端窗口(例如在VSCode中选择View → Terminal)使用pdb,或者b)使用内置的调试器(通常是pdb的图形化封装)。


符号追踪的局限性

FX 采用符号追踪(又称符号执行)系统,以可转换/可分析的形式捕获程序语义。该系统具有以下特点:

  • 追踪性:通过实际执行程序(实际是torch.nn.Module或函数)来记录操作
  • 符号性:执行过程中流经程序的数据并非真实数据,而是符号(FX术语中称为Proxy)

虽然符号追踪适用于大多数神经网络代码,但它仍存在一些局限性。


动态控制流

符号追踪的主要局限在于目前不支持动态控制流。也就是说,当循环或if语句的条件可能依赖于程序输入值时,就无法处理。

例如,我们来看以下程序:

def func_to_trace(x):if x.sum() 0:return torch.relu(x)else:return torch.neg(x)traced = torch.fx.symbolic_trace(func_to_trace)
"""<...>File "dyn.py", line 6, in func_to_traceif x.sum() 0:File "pytorch/torch/fx/proxy.py", line 155, in __bool__return self.tracer.to_bool(self)File "pytorch/torch/fx/proxy.py", line 85, in to_boolraise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""

if语句的条件依赖于x.sum()的值,而该值又依赖于函数输入x。由于x可能发生变化(例如向追踪函数传入新的输入张量时),这就形成了动态控制流。回溯信息会沿着代码向上追溯,展示这种情况发生的位置。


静态控制流

另一方面,系统支持所谓的静态控制流。静态控制流指的是那些在多次调用中值不会改变的循环或if语句。通常在PyTorch程序中,这种控制流出现在根据超参数决定模型架构的代码中。举个具体例子:

import torch
import torch.fxclass MyModule(torch.nn.Module):def __init__(self, do_activation : bool = False):super().__init__()self.do_activation = do_activationself.linear = torch.nn.Linear(512, 512)def forward(self, x):x = self.linear(x)# This if-statement is so-called static control flow.# Its condition does not depend on any input valuesif self.do_activation:x = torch.relu(x)return xwithout_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)traced_without_activation = torch.fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
"""
def forward(self, x):linear_1 = self.linear(x);  x = Nonereturn linear_1
"""traced_with_activation = torch.fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
"""
import torch
def forward(self, x):linear_1 = self.linear(x);  x = Nonerelu_1 = torch.relu(linear_1);  linear_1 = Nonereturn relu_1
"""

if self.do_activation 这个条件语句不依赖于任何函数输入,因此它是静态的。do_activation 可以被视为一个超参数,当 MyModule 的不同实例使用不同参数值时,生成的代码轨迹也会不同。这是一种有效模式,符号追踪功能支持这种模式。

许多动态控制流的实例在语义上其实是静态控制流。通过消除对输入值的数据依赖,这些实例可以支持符号追踪。具体方法包括:

  • 将值移至 Module 属性中
  • 在符号追踪期间将具体值绑定到参数上

def f(x, flag):if flag: return xelse: return x*2fx.symbolic_trace(f) # Fails!fx.symbolic_trace(f, concrete_args={'flag': True})

在真正动态控制流的情况下,包含此类代码的程序部分可以被追踪为对方法的调用(参见使用Tracer类自定义追踪)或函数调用(参见wrap()),而不是直接追踪这些代码本身。


torch函数

FX采用__torch_function__作为拦截调用的机制(更多技术细节请参阅技术概览)。某些函数(如Python内置函数或math模块中的函数)不受__torch_function__覆盖,但我们仍希望在符号追踪中捕获它们。例如:

import torch
import torch.fx
from math import sqrtdef normalize(x):"""Normalize `x` by the size of the batch dimension"""return x / sqrt(len(x))# It's valid Python code
normalize(torch.rand(3, 4))traced = torch.fx.symbolic_trace(normalize)
"""<...>File "sqrt.py", line 9, in normalizereturn x / sqrt(len(x))File "pytorch/torch/fx/proxy.py", line 161, in __len__raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""

错误提示表明内置函数 len 不被支持。

我们可以通过 wrap() API 将此类函数记录为跟踪中的直接调用:

torch.fx.wrap('len')
torch.fx.wrap('sqrt')traced = torch.fx.symbolic_trace(normalize)print(traced.code)
"""
import math
def forward(self, x):len_1 = len(x)sqrt_1 = math.sqrt(len_1);  len_1 = Nonetruediv = x / sqrt_1;  x = sqrt_1 = Nonereturn truediv
"""

使用 Tracer 类自定义追踪功能

Tracer 类是 symbolic_trace 功能的基础实现类。通过继承 Tracer 类,可以自定义追踪行为,例如:

class MyCustomTracer(torch.fx.Tracer):# Inside here you can override various methods# to customize tracing. See the `Tracer` API# referencepass# Let's use this custom tracer to trace through this module
class MyModule(torch.nn.Module):def forward(self, x):return torch.relu(x) + torch.ones(3, 4)mod = MyModule()traced_graph = MyCustomTracer().trace(mod)
# trace() returns a Graph. Let's wrap it up in a # GraphModule to make it runnable
traced = torch.fx.GraphModule(mod, traced_graph)

叶子模块

叶子模块是指在符号追踪过程中作为调用出现,而不会被继续追踪的模块。默认的叶子模块集合由标准torch.nn模块实例组成。例如:

class MySpecialSubmodule(torch.nn.Module):def forward(self, x):return torch.neg(x)class MyModule(torch.nn.Module):def __init__(self):super().__init__()self.linear = torch.nn.Linear(3, 4)self.submod = MySpecialSubmodule()def forward(self, x):return self.submod(self.linear(x))traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
# `linear` is preserved as a call, yet `submod` is traced though.
# This is because the default set of "Leaf Modules" includes all
# standard `torch.nn` modules.
"""
import torch
def forward(self, x):linear_1 = self.linear(x);  x = Noneneg_1 = torch.neg(linear_1);  linear_1 = Nonereturn neg_1
"""

可以通过重写 Tracer.is_leaf_module() 来自定义叶子模块集合。


杂项说明

  • 当前无法追踪张量构造函数(如torch.zerostorch.onestorch.randtorch.randntorch.sparse_coo_tensor):
    • 确定性构造函数(zerosones)仍可使用,其生成的值会作为常量嵌入追踪记录。仅当这些构造函数的参数涉及动态输入大小时才会出现问题,此时可改用ones_likezeros_like作为替代方案。
    • 非确定性构造函数(randrandn)会将单个随机值嵌入追踪记录,这通常不符合预期行为。变通方法是将torch.randn包装在torch.fx.wrap函数中并调用该包装函数。

(注:保留所有代码块及技术术语原貌,被动语态转为主动表述,长句拆分后保持技术严谨性)


@torch.fx.wrap
def torch_randn(x, shape):return torch.randn(shape)def f(x):return x + torch_randn(x, 5)
fx.symbolic_trace(f)

此行为可能在未来的版本中修复。

  • 类型注解
  • 支持 Python 3 风格的类型注解(例如
    func(x : torch.Tensor, y : int) -torch.Tensor),
    并且会通过符号追踪保留这些注解。

  • 目前不支持 Python 2 风格的注释类型注解
    # type: (torch.Tensor, int) -torch.Tensor

  • 目前不支持函数内部局部变量的类型注解。

  • 关于 training 标志和子模块的注意事项
  • 当使用像 torch.nn.functional.dropout 这样的函数时,通常会传入 self.training 作为训练参数。在 FX 追踪过程中,这个值很可能会被固定为一个常量。

import torch
import torch.fxclass DropoutRepro(torch.nn.Module):def forward(self, x):return torch.nn.functional.dropout(x, training=self.training)traced = torch.fx.symbolic_trace(DropoutRepro())
print(traced.code)
"""
def forward(self, x):dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False);  x = Nonereturn dropout
"""traced.eval()x = torch.randn(5, 3)
torch.testing.assert_close(traced(x), x)
"""
AssertionError: Tensor-likes are not close!Mismatched elements: 15 / 15 (100.0%)
Greatest absolute difference: 1.6207983493804932 at index (0, 2) (up to 1e-05 allowed)
Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed)
"""

然而,当使用标准的 nn.Dropout() 子模块时,训练标志会被封装起来,并且由于保留了 nn.Module 对象模型,可以对其进行修改。


class DropoutRepro2(torch.nn.Module):def __init__(self):super().__init__()self.drop = torch.nn.Dropout()def forward(self, x):return self.drop(x)traced = torch.fx.symbolic_trace(DropoutRepro2())
print(traced.code)
"""
def forward(self, x):drop = self.drop(x);  x = Nonereturn drop
"""traced.eval()x = torch.randn(5, 3)
torch.testing.assert_close(traced(x), x)

由于这一差异,建议将与动态training标志交互的模块标记为叶模块。


API 参考


torch.fx.symbolic_trace(root, concrete_args=None)

符号追踪 API

给定一个 nn.Module 或函数实例 root,该 API 会返回一个 GraphModule,这是通过记录追踪 root 时观察到的操作构建而成的。

concrete_args 参数允许你对函数进行部分特化,无论是为了移除控制流还是数据结构。

例如:

def f(a, b):if b == True:return a     else:return a * 2

由于控制流的存在,FX通常无法追踪此过程。不过,我们可以使用concrete_args来针对变量b的值进行特化处理,从而实现追踪:

f = fx.symbolic_trace(f, concrete_args={"b": False})
assert f(3, False) == 6

请注意,虽然您仍可以传入不同的b值,但这些值将被忽略。

我们还可以使用concrete_args来消除函数中对数据结构的处理。这将利用pytrees来展平您的输入。为了避免过度特化,对于不应特化的值,请传入fx.PH。例如:

def f(x):out = 0for v in x.values():out += vreturn outf = fx.symbolic_trace(f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}})
assert f({"a": 1, "b": 2, "c": 4}) == 7

参数

  • root (Union[torch.nn.Module, Callable]) - 待追踪并转换为图表示形式的模块或函数
  • concrete_args (Optional[Dict[str, any]]) - 需要部分特化的输入参数

返回从root记录的操作所创建的模块。

返回类型:GraphModule

注意:此API保证向后兼容性。


torch.fx.wrap(fn_or_name)

该函数可在模块级作用域调用,将fn_or_name注册为"叶子函数"。

"叶子函数"在FX跟踪中会保留为CallFunction节点,而不会被进一步跟踪。


# foo/bar/baz.py
def my_custom_function(x, y):return x * x + y * ytorch.fx.wrap("my_custom_function")def fn_to_be_traced(x, y):# When symbolic tracing, the below call to my_custom_function will be inserted into# the graph rather than tracing it.return my_custom_function(x, y)

该函数也可以等效地用作装饰器:

# foo/bar/baz.py
@torch.fx.wrap
def my_custom_function(x, y):return x * x + y * y

包装函数可以被视为"叶子函数",类似于"叶子模块"的概念,也就是说,这些函数在FX跟踪中会保留为调用点,而不会被进一步追踪。

参数

  • fn_or_name (Union[str, Callable]) - 当被调用时,要插入到图中的函数或全局函数名称

注意:此API保证向后兼容性。


class torch.fx.GraphModule(*args, **kwargs)

GraphModule 是由 fx.Graph 生成的 nn.Module。GraphModule 具有一个 graph 属性,以及从该 graph 生成的 codeforward 属性。

警告:当重新分配 graph 时,codeforward 将自动重新生成。但如果你编辑了 graph 的内容而没有重新分配 graph 属性本身,则必须调用 recompile() 来更新生成的代码。

注意:此 API 保证向后兼容性。


__init__(root, graph, class_name='GraphModule')

构建一个 GraphModule。

参数

  • root (Union[torch.nn.Module , Dict[str, Any])root 可以是 nn.Module 实例,也可以是将字符串映射到任意属性类型的字典。

root 是 Module 时,Graph 的 Nodes 中 target 字段对基于 Module 的对象(通过限定名称引用)的任何引用,都会从 root 的 Module 层次结构中的相应位置复制到 GraphModule 的模块层次结构中。

root 是字典时,Node 的 target 中找到的限定名称将直接在字典的键中查找。字典映射的对象将被复制到 GraphModule 模块层次结构中的适当位置。

  • graph (Graph)graph 包含此 GraphModule 用于代码生成的节点
  • class_name (str)name 表示此 GraphModule 的名称,用于调试目的。如果未设置,所有错误消息将报告为源自 GraphModule。将其设置为 root 的原始名称或在转换上下文中合理的名称可能会有所帮助。

注意:此 API 保证向后兼容性。


add_submodule(target, m)

将给定的子模块添加到self中。

如果target是子路径且对应位置尚未存在模块,此方法会安装空的模块。

参数

  • target (str) - 新子模块的完整限定字符串名称
    (参见nn.Module.get_submodule中的示例了解如何指定完整限定字符串)
  • m (Module) - 子模块本身;即我们想要安装到当前模块中的实际对象

返回

子模块是否能够被插入。要使该方法返回True,target表示的链中每个对象必须满足以下条件之一:
a) 尚不存在,或
b) 引用的是nn.Module(而非参数或其他属性)

返回类型:bool

注意:此API保证向后兼容性。


property code:  str 

返回从该 GraphModule 底层 Graph 生成的 Python 代码。

delete_all_unused_submodules()

***
Deletes all unused submodules from `self`.A Module is considered “used” if any one of the following is true:
1、It has children that are used
2、Its forward is called directly via a `call_module` node
3、It has a non-Module attribute that is used from a `get_attr` nodeThis method can be called to clean up an `nn.Module` without
manually calling `delete_submodule` on each unused submodule.
***
Note: Backwards-compatibility for this API is guaranteed.delete_submodule(target)

self中删除指定的子模块。

如果target不是有效的目标,则不会删除该模块。

参数

  • target (str) - 新子模块的完全限定字符串名称
    (有关如何指定完全限定字符串的示例,请参阅nn.Module.get_submodule

返回值
表示目标字符串是否引用了我们要删除的子模块。返回值为False意味着target不是有效的子模块引用。

返回类型 : bool

注意:此API保证向后兼容性。


property graph: [Graph](https://pytorch.org/docs/stable/data.html#torch.fx.Graph "torch.fx.graph.Graph")

返回该 GraphModule 底层对应的 Graph


print_readable(print_output=True, include_stride=False, include_device=False, colored=False)

返回为当前 GraphModule 及其子 GraphModule 生成的 Python 代码

警告:此 API 为实验性质,且保证向后兼容性。


recompile()

根据其 graph 属性重新编译该 GraphModule。在编辑包含的 graph 后应调用此方法,否则该 GraphModule 生成的代码将过期。

注意:此 API 保证向后兼容性。

返回类型:PythonCode


to_folder(folder, module_name='FxModule')

将模块以 module_name 名称转储到 folder 目录下,以便可以通过 from <folder> import <module_name> 方式导入。

参数:

folder (Union [str, os.PathLike]): 用于输出代码的目标文件夹路径
module_name (str): 在输出代码时使用的顶层模块名称

警告:此 API 为实验性质,不保证向后兼容性。


class torch.fx.Graph(owning_module=None, tracer_cls=None, tracer_extras=None)

Graph 是 FX 中间表示中使用的主要数据结构。

它由一系列 Node 组成,每个节点代表调用点(或其他语法结构)。这些 Node 的集合共同构成了一个有效的 Python 函数。

例如,以下代码


import torch
import torch.fxclass MyModule(torch.nn.Module):def __init__(self):super().__init__()self.param = torch.nn.Parameter(torch.rand(3, 4))self.linear = torch.nn.Linear(4, 5)def forward(self, x):return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)m = MyModule()
gm = torch.fx.symbolic_trace(m)

将生成以下图表:

print(gm.graph)

graph(x):%linear_weight : [num_users=1] = self.linear.weight%add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})%linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})%relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})%sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})%topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})return topk_1

关于Graph中操作的具体语义,请参阅Node文档。

注意:本API保证向后兼容性。


__init__(owning_module=None, tracer_cls=None, tracer_extras=None)

构建一个空图。

注意:此 API 保证向后兼容性。


call_function(the_function, args=None, kwargs=None, type_expr=None)

Graph中插入一个call_function类型的Nodecall_function节点表示对Python可调用对象的调用,由the_function指定。

参数

  • the_function (Callable[...*, Any]) – 要调用的函数。可以是任何PyTorch运算符、Python函数,或属于builtinsoperator命名空间的成员。
  • args (Optional[Tuple[Argument*, ...]]) – 传递给被调用函数的位置参数。
  • kwargs (Optional[Dict[str, Argument]]) – 传递给被调用函数的关键字参数。
  • type_expr (Optional[Any]) – 可选的类型注解,表示该节点输出值的Python类型。

返回

新创建并插入的call_function节点。

返回类型

Node

注意:此方法的插入点和类型表达式规则与Graph.create_node()相同。

注意:此API保证向后兼容性。


call_method(method_name, args=None, kwargs=None, type_expr=None)

Graph中插入一个call_method节点。call_method节点表示对args第0个元素调用指定方法。

参数

  • method_name (str) - 要应用于self参数的方法名称。例如,如果args[0]是一个表示TensorNode,那么要对该Tensor调用relu()方法时,需将relu作为method_name传入。
  • args (Optional[Tuple[Argument*, ...]]) - 要传递给被调用方法的位置参数。注意这应该包含一个self参数。
  • kwargs (Optional[Dict[str, Argument]]) - 要传递给被调用方法的关键字参数
  • type_expr (Optional[Any]) - 可选的类型注解,表示该节点输出结果的Python类型。

返回

新创建并插入的call_method节点。

返回类型

Node

注意:本方法的插入点和类型表达式规则与Graph.create_node()相同。

注意:此API保证向后兼容性。


call_module(module_name, args=None, kwargs=None, type_expr=None)

Graph中插入一个call_module类型的Node节点。call_module节点表示对Module层级结构中某个Module的forward()函数的调用。

参数

  • module_name (str) - 要调用的Module在层级结构中的限定名称。例如,若被追踪的Module有一个名为foo的子模块,而该子模块又包含名为bar的子模块,则应以foo.bar作为module_name来调用该模块。
  • args (Optional[Tuple[Argument*, ...]]) - 传递给被调用方法的位置参数。注意:此处不应包含self参数。
  • kwargs (Optional[Dict[str, Argument]]) - 传递给被调用方法的关键字参数
  • type_expr (Optional[Any]) - 可选类型注解,表示该节点输出值的Python类型。

返回

新创建并插入的call_module节点。

返回类型:Node

注意:本方法的插入点与类型表达式规则与Graph.create_node()相同。

注意:本API保证向后兼容性。


create_node(op, target, args=None, kwargs=None, name=None, type_expr=None)

创建一个 Node 并将其添加到当前插入点的 Graph 中。

注意:当前插入点可以通过 Graph.inserting_before()Graph.inserting_after() 进行设置。

参数

  • op (str) - 该节点的操作码。可选值包括 ‘call_function’、‘call_method’、‘get_attr’、‘call_module’、‘placeholder’ 或 ‘output’。这些操作码的语义在 Graph 的文档字符串中有详细说明。
  • args (Optional[Tuple[Argument*, ...]]) - 该节点的参数元组。
  • kwargs (Optional[Dict[str, Argument]]) - 该节点的关键字参数。
  • name (Optional[str]) - 为 Node 指定的可选字符串名称。这将影响生成的 Python 代码中赋值给该节点的变量名。
  • type_expr (Optional[Any]) - 可选类型注解,表示该节点输出值的 Python 类型。

返回

新创建并插入的节点。

返回类型:Node

注意:此 API 保证向后兼容。


eliminate_dead_code(is_impure_node=None)

根据图中各节点的用户数量及是否具有副作用,移除所有死代码。调用前必须确保图已完成拓扑排序。

参数

  • is_impure_node (Optional[Callable[[Node],* [bool]]]) —— 用于判断节点是否为非纯函数的回调函数。若未提供该参数,则默认使用 Node.is_impure 方法。

返回值:返回布尔值,表示该过程是否导致图结构发生变更。

返回类型:bool

示例

在消除死代码前,下方表达式 a = x + 1 中的变量 a 无用户引用,因此可从图中安全移除而不影响结果。


def forward(self, x):a = x + 1return x + self.attr_1

消除死代码后,a = x + 1 已被移除,前向传播部分的其他代码保留不变。


def forward(self, x):return x + self.attr_1

警告:死代码消除机制虽然采用了一些启发式方法来避免删除具有副作用的节点(参见 Node.is_impure),但总体覆盖率非常不理想。因此,除非你明确知道当前 FX 计算图完全由无副作用的操作构成,或者自行提供了检测副作用节点的自定义函数,否则不应假设调用此方法是安全可靠的。

注意:本 API 保证向后兼容性。


erase_node(to_erase)

Graph中删除一个Node。如果该节点在Graph中仍被使用,将抛出异常。

参数

  • to_erase (Node) – 要从Graph中删除的Node

注意:此API保证向后兼容性。


find_nodes(*, op, target=None, sort=True)

支持快速查询节点

参数

  • op (str) – 操作名称
  • target (Optional[Target]) – 节点目标。对于call_function操作,target为必填项;其他操作中target为可选参数。
  • sort ([bool]) – 是否按节点在图中出现的顺序返回结果。

返回值:返回符合指定op和target条件的节点迭代器。

警告:此API为实验性质,且保证向后兼容。


get_attr(qualified_name, type_expr=None)

向图中插入一个 get_attr 节点。get_attr 类型的 Node 表示从 Module 层次结构中获取某个属性。

参数

  • qualified_name (str) - 要获取属性的全限定名称。例如,若被追踪的 Module 包含名为 foo 的子模块,该子模块又包含名为 bar 的子模块,而 bar 拥有名为 baz 的属性,则应将全限定名称 foo.bar.baz 作为 qualified_name 传入。
  • type_expr (Optional[Any]) - 可选的类型注解,用于表示该节点输出值的 Python 类型。

返回

新创建并插入的 get_attr 节点。

返回类型:Node

注意:本方法的插入点与类型表达式规则与 Graph.create_node 方法保持一致。

注意:此 API 保证向后兼容性。


graph_copy(g, val_map, return_output_node=False)

将给定图中的所有节点复制到 self 中。

参数

  • g (Graph) – 作为节点复制来源的原始图。
  • val_map (Dict[Node,* Node]) – 用于存储节点映射关系的字典,键为 g 中的节点,值为 self 中的对应节点。注意:val_map 可预先包含值以实现特定值的复制覆盖。

返回值:如果 g 包含输出节点,则返回 self 中与 g 输出值等效的值;否则返回 None

返回类型:Optional[Union [tuple [Argument, …], Sequence [Argument], Mapping [str , Argument], slice , range , Node, str , int , float, bool , complex , [dtype](tensor_attributes.html#torch.dtype "torch.dtype"), Tensor , device , memory_format , layout , OpOverload, [SymInt](torch.html#torch.SymInt "torch.SymInt"), SymBool , SymFloat ]]

注意:本API保证向后兼容性。


inserting_after(n=None)

设置 create_node 及相关方法在图中插入节点的位置。当在 with 语句中使用时,这会临时设置插入点,并在 with 语句退出时恢复原位置。


with g.inserting_after(n):...  # inserting after node n
...  # insert point restored to what it was previously
g.inserting_after(n)  #  set the insert point permanently

参数:

n (可选[Node]): 要在其之前插入的节点。如果为None,则会在整个图的起始位置之后插入。

返回:
一个资源管理器,它会在__exit__时恢复插入点。

注意:此API保证向后兼容性。


inserting_before(n=None)

设置 create_node 及相关方法在图中插入节点的基准位置。当在 with 语句中使用时,这将临时设置插入点,并在 with 语句退出时恢复原位置。


with g.inserting_before(n):...  # inserting before node n
...  # insert point restored to what it was previously
g.inserting_before(n)  #  set the insert point permanently

参数:

n (Optional[Node]): 要插入位置的前一个节点。如果为None,则会在整个图的起始位置前插入。

返回:
一个资源管理器,该管理器会在__exit__时恢复插入点。

注意:此API保证向后兼容性。


lint()

对该图执行多项检查以确保其结构正确。具体包括:

  • 检查节点是否具有正确的所有权(由本图所有)
  • 检查节点是否按拓扑顺序排列
  • 若该图拥有所属的GraphModule,则检查目标是否存在该GraphModule中

注:本API保证向后兼容性。


node_copy(node, arg_transform=<function Graph.<lambda>>)

将节点从一个图复制到另一个图中。arg_transform需要将节点所在图的参数转换为目标图(self)的参数。示例:

# Copying all the nodes in `g` into `new_graph`
g: torch.fx.Graph = ...
new_graph = torch.fx.graph()
value_remap = {}for node in g.nodes:value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])

参数

  • node (Node) – 要复制到 self 中的节点。
  • arg_transform (Callable[[Node], Argument]) – 一个函数,用于将节点 argskwargs 中的 Node 参数转换为 self 中的等效参数。最简单的情况下,该函数应从原始图中节点到 self 的映射表中检索值。

返回类型:Node

注意:此 API 保证向后兼容性。


property nodes: _node_list

获取构成该图的所有节点列表。

请注意,这个Node列表是以双向链表的形式表示的。在迭代过程中进行修改(例如删除节点、添加节点)是安全的。

返回值:一个双向链表结构的节点列表。注意可以对该列表调用reversed方法来切换迭代顺序。


on_generate_code(make_transformer)

在生成 Python 代码时注册转换器函数

参数:

make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]):返回待注册代码转换器的函数。

该函数由 on_generate_code 调用以获取代码转换器。

此函数的输入参数为当前已注册的代码转换器(若未注册则为 None),以便在不需要覆盖时使用。该机制可用于串联多个代码转换器。

返回值:一个上下文管理器,当在 with 语句中使用时,会自动恢复先前注册的代码转换器。


示例:

gm: fx.GraphModule = ...# This is a code transformer we want to register. This code
# transformer prepends a pdb import and trace statement at the very
# beginning of the generated torch.fx code to allow for manual
# debugging with the PDB library.
def insert_pdb(body):return ["import pdb; pdb.set_trace()\n", body]# Registers `insert_pdb`, and overwrites the current registered
# code transformer (given by `_` to the lambda):
gm.graph.on_generate_code(lambda _: insert_pdb)# Or alternatively, registers a code transformer which first
# runs `body` through existing registered transformer, then
# through `insert_pdb`:
gm.graph.on_generate_code(lambda current_trans: (lambda body: insert_pdb(current_trans(body) if current_trans else body))
)gm.recompile()
gm(inputs)  # drops into pdb

该函数也可作为上下文管理器使用,其优势在于能自动恢复之前注册的代码转换器。


# ... continue from previous examplewith gm.graph.on_generate_code(lambda _: insert_pdb):# do more stuff with `gm`...gm.recompile()gm(inputs)  # drops into pdb# now previous code transformer is restored (but `gm`'s code with pdb
# remains - that means you can run `gm` with pdb here too, until you # run next `recompile()`).

警告:此 API 为实验性质,且向后兼容。


output(result, type_expr=None)

output Node 插入到 Graph 中。output 节点代表 Python 代码中的 return 语句。result 是应当返回的值。

参数

  • result (Argument) – 要返回的值。
  • type_expr (Optional[Any]) – 可选的类型注解,表示此节点输出将具有的 Python 类型。

注意:此方法的插入点和类型表达式规则与 Graph.create_node 相同。

注意:此 API 保证向后兼容性。


output_node()

警告:此 API 为实验性质,且向后兼容。

返回值类型:Node


placeholder(name, type_expr=None, default_value)

在图中插入一个placeholder节点。placeholder表示函数的输入参数。

参数

  • name (str) - 输入值的名称。这对应于该Graph所表示函数的位置参数名称。
  • type_expr (Optional[Any]) - 可选的类型注解,表示该节点输出值的Python类型。在某些情况下(例如当函数后续用于TorchScript编译时),这是生成正确代码所必需的。
  • default_value (Any) - 该函数参数的默认值。注意:为了允许None作为默认值,当参数没有默认值时,应传递inspect.Signature.empty来指定。

返回类型:Node

注意:此方法的插入点和类型表达式规则与Graph.create_node相同。

注意:此API保证向后兼容性。


print_tabular()

以表格形式打印图的中间表示。注意:此API需要安装tabulate模块。

注:该API保证向后兼容性。


process_inputs(*args)

处理参数以便它们可以传递到 FX 计算图中。

警告:此 API 为实验性质,且向后兼容。


process_outputs(out)

警告:此 API 为实验性质,且向后兼容。


python_code(root_module, *, verbose=False, include_stride=False, include_device=False, colored=False)

将这段Graph转换为有效的Python代码。

参数

  • root_module (str) – 用于查找限定名称目标的根模块名称。通常为’self’。

返回值:src: 表示该对象的Python源代码
globals: 包含src中全局名称及其引用对象的字典

返回类型:一个包含两个字段的PythonCode对象

注意:此API保证向后兼容性。


set_codegen(codegen)

警告:此 API 为实验性功能,且向后兼容。


class torch.fx.Node(graph, name, op, target, args, kwargs, return_type=None)

Node 是表示 Graph 中单个操作的数据结构。在大多数情况下,Node 表示对各种实体的调用点,例如运算符、方法和模块(某些例外包括指定函数输入和输出的节点)。每个 Node 都有一个由其 op 属性指定的函数。不同 op 值的 Node 语义如下:

  • placeholder 表示函数输入。name 属性指定该值的名称。target 同样是参数的名称。args 包含:1) 空值,或 2) 表示函数输入默认参数的单个参数。kwargs 无关紧要。占位符对应于图形输出中的函数参数(例如 x)。
  • get_attr 从模块层次结构中检索参数。name 同样是获取结果后赋值的名称。target 是参数在模块层次结构中的完全限定名称。argskwargs 无关紧要。
  • call_function 将自由函数应用于某些值。name 同样是赋值目标的名称。target 是要应用的函数。argskwargs 表示函数的参数,遵循 Python 调用约定。
  • call_module 将模块层次结构中的 forward() 方法应用于给定参数。name 同前。target 是要调用的模块在模块层次结构中的完全限定名称。argskwargs 表示调用模块时的参数(不包括 self 参数*)。
  • call_method 调用值的方法。name 类似。target 是要应用于 self 参数的方法名称字符串。argskwargs 表示调用模块时的参数(包括 self 参数*)。
  • output 在其 args[0] 属性中包含跟踪函数的输出。这对应于图形输出中的 “return” 语句。

注意:此 API 保证向后兼容。


property all_input_nodes:  list ['Node'] 
Return all Nodes that are inputs to this Node. This is equivalent to iterating over `args` and `kwargs` and only collecting the values that are Nodes.Returns
List of `Nodes` that appear in the `args` and `kwargs` of this `Node`, in that order.append(x)

在图的节点列表中,将 x 插入到当前节点之后。

等价于调用 self.next.prepend(x)

参数

  • x (Node) – 要插入到当前节点后的节点。必须属于同一个图。

注意:此 API 保证向后兼容。


property args:  tuple [Union [tuple ['Argument', 
...], collections.abc.Sequence ['Argument'], collections.abc.Mapping[str , 'Argument'], slice , range , torch.fx.node.Node, str , int , float, bool , complex , torch.dtype , torch.Tensor, torch.device , torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType], 
...] 

Node的参数元组。参数的具体含义取决于节点的操作码(opcode)。更多信息请参阅Node文档字符串。

允许对此属性进行赋值操作。所有关于使用情况和用户的记录都会在赋值时自动更新。


format_node(placeholder_names=None, maybe_return_typename=None)

返回一个描述性的字符串表示形式self

该方法可不带参数使用,作为调试工具。

此函数也用于Graph__str__方法内部。placeholder_namesmaybe_return_typename中的字符串共同构成了该Graph所属GraphModule中自动生成的forward函数的签名。placeholder_namesmaybe_return_typename不应在其他情况下使用。

参数

  • placeholder_names (Optional[list[str]]) - 一个列表,用于存储表示生成的forward函数中占位符的格式化字符串。仅供内部使用。
  • maybe_return_typename (Optional[list[str]]) - 一个单元素列表,用于存储表示生成的forward函数输出的格式化字符串。仅供内部使用。

返回

如果1)我们在Graph__str__方法中将format_node用作内部辅助工具,且2)self是一个占位符Node,则返回None。否则,返回当前Node的描述性字符串表示形式。

返回类型:str

注意:此API保证向后兼容。


insert_arg(idx, arg)

在参数列表的指定索引位置插入一个位置参数。

参数

  • idx ( int ) – 要插入到self.args中元素之前的索引位置。
  • arg (Argument) – 要插入到args中的新参数值

注意:本API保证向后兼容性。


is_impure()

返回该操作是否为不纯操作,即判断其操作是否为占位符或输出,或者是否为不纯的call_functioncall_module

返回值:指示该操作是否不纯。

返回类型:bool

警告:此API为实验性质,且向后兼容。


property kwargs:  dict[str , Union [tuple ['Argument', 
...], collections.abc.Sequence['Argument'], collections.abc.Mapping, [str , 'Argument'], slice , range , torch.fx.node.Node, str , int , float, bool , complex , torch.dtype , torch.Tensor, torch.device , torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType]] 

Node的关键字参数字典。参数的解析取决于节点的操作码。更多信息请参阅Node文档字符串。

允许对此属性进行赋值。所有关于使用情况和用户的统计都会在赋值时自动更新。


property next: Node

返回链表中下一个Node节点。

返回值:链表中下一个Node节点。


normalized_arguments(root, arg_types=None, kwarg_types=None, normalize_to_only_use_kwargs=False)

返回经过标准化的Python目标参数。这意味着当normalize_to_only_use_kwargs为真时,args/kwargs将与模块/函数的签名匹配,并按位置顺序仅返回kwargs。

同时会填充默认值。不支持仅限位置参数或可变参数。

支持模块调用。

可能需要arg_typeskwarg_types来消除重载歧义。

参数

  • root (torch.nn.Module) – 用于解析模块目标的基模块
  • arg_types (Optional[Tuple[Any]]) – 参数的元组类型
  • kwarg_types (Optional[Dict[str, Any]]) – 关键字参数的字典类型
  • normalize_to_only_use_kwargs ([bool]) – 是否标准化为仅使用kwargs

返回
返回命名元组ArgsKwargsPair,若失败则返回None

返回类型
Optional[ArgsKwargsPair]

警告:该API为实验性质,不保证向后兼容。


prepend(x)

在图的节点列表中,在此节点前插入x。示例:

Before: p -selfbx -x -ax
After:  p -x -selfbx -ax

参数

  • x (Node) – 要放置在该节点之前的节点。必须是同一图的成员。

注意:此 API 保证向后兼容。


property prev: Node

返回链表中当前节点的前一个Node

返回值:链表中当前节点的前一个Node


replace_all_uses_with(replace_with, delete_user_cb=<function Node.<lambda>>, *, propagate_meta=False)

将图中所有使用 self 的地方替换为节点 replace_with

参数

  • replace_with (Node) – 用于替换所有 self 的节点。
  • delete_user_cb (Callable) – 回调函数,用于判断是否应移除某个使用原 self 节点的用户节点。
  • propagate_meta ([bool]) – 是否将原节点 .meta 字段的所有属性复制到替换节点上。出于安全考虑,仅当替换节点本身没有 .meta 字段时才允许此操作。

返回值

返回受此变更影响的节点列表。

返回类型:list [Node]

注意:此 API 保证向后兼容。


replace_input_with(old_input, new_input)

遍历 self 的输入节点,将所有 old_input 实例替换为 new_input

参数

  • old_input (Node) – 需要被替换的旧输入节点。
  • new_input (Node) – 用于替换 old_input 的新输入节点。

注意:此 API 保证向后兼容性。


property stack_trace: Optional[str ] 

返回在追踪过程中记录的 Python 堆栈跟踪信息(如果有)。

当使用 fx.Tracer 进行追踪时,该属性通常由 Tracer.create_proxy 填充。若需在追踪过程中记录堆栈跟踪以用于调试,请在 Tracer 实例上设置 record_stack_traces = True

当使用 dynamo 进行追踪时,该属性默认会由 OutputGraph.create_proxy 填充。

stack_trace 的字符串末尾将包含最内层的调用帧。


update_arg(idx, arg)

更新现有位置参数以包含新值

调用后,self.args[idx] == arg 将成立。

参数

  • idx ( int ) - 要更新元素在 self.args 中的索引位置
  • arg (Argument) - 要写入 args 的新参数值

注意:此 API 保证向后兼容性。


update_kwarg(key, arg)

更新现有关键字参数以包含新值

arg。调用后,self.kwargs[key] == arg

参数

  • key (str) - 要更新的元素在self.kwargs中的键名
  • arg (Argument) - 要写入kwargs的新参数值

注意:此API保证向后兼容性。


class torch.fx.Tracer(autowrap_modules=(math,), autowrap_functions=())

Tracer 是实现 torch.fx.symbolic_trace 符号追踪功能的类。调用 symbolic_trace(m) 等价于执行 Tracer().trace(m)

可以通过继承 Tracer 类来覆盖追踪过程中的各种行为。具体可覆盖的行为详见该类方法的文档字符串。

注意:此 API 保证向后兼容。


call_module(m, forward, args, kwargs)

该方法定义了当Tracer遇到对nn.Module实例调用时的行为。

默认行为是通过is_leaf_module检查被调用的模块是否为叶子模块。如果是,则在Graph中生成指向mcall_module节点;否则正常调用该Module,并跟踪其forward函数中的操作。

可通过重写此方法实现自定义行为,例如:

  • 创建嵌套的追踪GraphModules
  • 实现跨Module边界追踪时的特殊处理

参数说明:

  • m (Module) - 当前被调用的模块实例
  • forward (Callable) - 待调用模块的forward()方法
  • args (Tuple) - 模块调用点的参数元组
  • kwargs (Dict) - 模块调用点的关键字参数字典

返回值:

  • 若生成call_module节点,则返回Proxy代理值
  • 否则返回模块调用的原始结果

返回类型:任意类型

注意:本API保证向后兼容性。


create_arg(a)

一种方法,用于指定在准备值作为Graph中节点的参数时追踪的行为。

默认行为包括:

1、遍历集合类型(如元组、列表、字典)并递归地对元素调用create_args

2、给定一个Proxy对象,返回底层IR Node的引用。

3、给定一个非Proxy的Tensor对象,为以下情况生成IR:

  • 对于Parameter,生成一个引用该Parameter的get_attr节点。
  • 对于非Parameter的Tensor,将该Tensor存储在一个特殊属性中,并引用该属性。

可以重写此方法以支持更多类型。


参数

  • a (Any) – 将被作为ArgumentGraph中使用的值。

返回值:将值a转换为适当的Argument

返回类型:Argument


注意:此API保证向后兼容。


create_args_for_root(root_fn, is_module, concrete_args=None)

root模块的签名创建对应的placeholder节点。该方法会检查root模块的签名并据此生成这些节点,同时支持*args**kwargs参数。

警告:此API为实验性质,且向后兼容。


create_node(kind, target, args, kwargs, name=None, type_expr=None)

根据给定的目标、参数、关键字参数和名称插入一个图节点。

该方法可以被重写,用于在节点创建过程中对使用的值进行额外检查、验证或修改。例如,可能希望禁止记录原地操作。

注意:此API保证向后兼容性。

返回类型:Node


create_proxy(kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None)

根据给定的参数创建一个节点,然后返回包裹在 Proxy 对象中的节点。

如果 kind = ‘placeholder’,则表示我们正在创建一个代表函数参数的节点。若需要编码默认参数,则使用 args 元组。对于 placeholder 类型的节点,args 在其他情况下为空。

注意:此 API 保证向后兼容性。


get_fresh_qualname(prefix)

获取一个基于前缀的新名称并返回。该函数确保生成的名称不会与图中现有属性发生冲突。

注意:此API保证向后兼容。

返回类型:str


getattr(attr, attr_val, parameter_proxy_cache)

该方法定义了当对nn.Module实例调用getattr时,该Tracer的行为表现。

默认情况下,其行为是返回该属性的代理值。同时会将代理值存入parameter_proxy_cache中,以便后续调用能复用该代理而非新建。

可通过重写此方法来实现不同行为——例如在查询参数时不返回代理。

参数说明:

  • attr (str) - 被查询的属性名称
  • attr_val (Any) - 该属性的值
  • parameter_proxy_cache (Dict[str, Any]) - 属性名到代理值的映射缓存

返回值:
getattr调用的返回结果。

警告:此API属于实验性质,且不保证向后兼容。


is_leaf_module(m, module_qualified_name)

一种用于判断给定nn.Module是否为"叶子"模块的方法。

叶子模块是指出现在IR(中间表示)中的原子单元,通过call_module调用进行引用。默认情况下,PyTorch标准库命名空间(torch.nn)中的模块都属于叶子模块。除非通过本参数特别指定,否则其他所有模块都会被追踪并记录其组成操作。

参数说明:

  • m (Module) - 被查询的模块
  • module_qualified_name (str) - 该模块到根模块的路径。例如,若模块层级结构中子模块foo包含子模块bar,而bar又包含子模块baz,则该模块的限定名将显示为foo.bar.baz

返回类型:bool

注意:本API保证向后兼容性。


iter(obj)

当代理对象被迭代时调用,例如在控制流中使用时。通常我们不知道如何处理,因为我们不知道代理的值,但自定义跟踪器可以通过 create_node 向图节点附加更多信息,并可以选择返回一个迭代器。

注意:此 API 保证向后兼容性。

返回类型:迭代器


keys(obj)

当代理对象的 keys() 方法被调用时触发。这是在代理对象上调用 ** 时发生的情况。该方法应返回一个迭代器,如果 ** 需要在自定义追踪器中生效。

注意:此 API 保证向后兼容。

返回类型:任意


path_of_module(mod)

这是一个辅助方法,用于在root模块的层级结构中查找mod的限定名称。例如,如果root有一个名为foo的子模块,而foo又有一个名为bar的子模块,那么将bar传入此函数将返回字符串"foo.bar"。

参数

  • mod (str) – 需要获取限定名称的Module

返回类型:str

注意:此API保证向后兼容性。


proxy(node)

注意:此 API 保证向后兼容性。

返回类型:Proxy

to_bool(obj)

当代理对象需要转换为布尔值时调用,例如在控制流中使用时。通常我们无法确定如何处理,因为不知道代理的具体值,但自定义追踪器可以通过create_node向图节点附加更多信息,并选择返回一个值。

注意:此API保证向后兼容。

返回类型:bool


trace(root, concrete_args=None)

追踪 root 并返回对应的 FX Graph 表示形式。root 可以是 nn.Module 实例或 Python 可调用对象。

请注意,在此调用后,self.root 可能与传入的 root 不同。例如,当向 trace() 传递自由函数时,我们会创建一个 nn.Module 实例作为根节点,并添加嵌入的常量。

参数

  • root (Union[Module, Callable]) – 需要追踪的 Module 或函数。该参数保证向后兼容性。
  • concrete_args (Optional[Dict[str, any]]) – 不应被视为代理的具体参数。此参数为实验性功能,其向后兼容性作保证。

返回值:表示传入 root 语义的 Graph 对象。

返回类型:Graph

注意:此 API 保证向后兼容性。


class torch.fx.Proxy(node, tracer=None)

Proxy对象是Node包装器,在符号追踪过程中流经程序,并记录它们接触到的所有操作(包括torch函数调用、方法调用和运算符)到不断增长的FX Graph中。

如果需要进行图变换,您可以在原始Node上封装自己的Proxy方法,这样就可以使用重载运算符向Graph添加额外内容。

Proxy对象不可迭代。换句话说,如果在循环中或作为*args/**kwargs函数参数使用Proxy,符号追踪器会抛出错误。

有两种主要解决方法:

1、将不可追踪的逻辑提取到顶层函数中,并使用fx.wrap进行处理。
2、如果控制流是静态的(即循环次数基于某些超参数),可以保持代码在原位,并重构为类似形式:

for i in range(self.some_hyperparameter):indexed_item = proxied_value[i]

如需更深入了解 Proxy 的内部实现细节,请查阅 torch/fx/README.md 文件中的 “Proxy” 章节。


注意:本 API 保证向后兼容性。


class torch.fx.Interpreter(module, garbage_collect_values=True, graph=None)

解释器(Interpreter)会逐节点(Node-by-Node)执行FX图。这种模式在许多场景下非常有用,包括编写代码转换器以及分析过程。

通过重写Interpreter类中的方法,可以自定义执行行为。以下是按调用层次结构划分的可重写方法映射:

run()+-- run_node+-- placeholder()+-- get_attr()+-- call_function()+-- call_method()+-- call_module()+-- output()

示例

假设我们需要将所有 torch.neg 实例与 torch.sigmoid 互换(包括它们对应的 Tensor 方法等价形式)。我们可以通过如下方式继承 Interpreter 类:

class NegSigmSwapInterpreter(Interpreter):def call_function(self, target: Target, args: Tuple, kwargs: Dict) -Any:if target == torch.sigmoid:return torch.neg(args, *kwargs)return super().call_function(target, args, kwargs)def call_method(self, target: Target, args: Tuple, kwargs: Dict) -Any:if target == "neg":call_self, args_tail = argsreturn call_self.sigmoid(args_tail, *kwargs)return super().call_method(target, args, kwargs)def fn(x):return torch.sigmoid(x).neg()gm = torch.fx.symbolic_trace(fn)
input = torch.randn(3, 4)
result = NegSigmSwapInterpreter(gm).run(input)
torch.testing.assert_close(result, torch.neg(input).sigmoid())

参数

  • module ( torch.nn.Module ) – 待执行的模块
  • garbage_collect_values ([bool]) – 是否在模块执行过程中最后一次使用后删除值。这能确保执行期间内存使用最优。可以禁用此功能,例如通过查看Interpreter.env属性来检查执行中的所有中间值。
  • graph (Optional[Graph]) – 如果传入该参数,解释器将执行此图而非module.graph,并使用提供的模块参数来满足任何状态请求。

注意:此API保证向后兼容性。


boxed_run(args_list)

通过解释方式运行模块并返回结果。该过程采用"boxed"调用约定,即传递一个参数列表(这些参数会被解释器自动清除),从而确保输入张量能够及时释放。

注意:本API保证向后兼容性。


call_function(target, args, kwargs)

执行一个call_function节点并返回结果。

参数

  • target (Target) – 该节点的调用目标。关于语义的详细信息请参阅Node
  • args (Tuple) – 本次调用的位置参数元组
  • kwargs (Dict) – 本次调用的关键字参数字典

返回类型:任意类型

返回值: 函数调用返回的值

注意:此API保证向后兼容性。


call_method(target, args, kwargs)

执行一个 call_method 节点并返回结果。

参数

  • target (Target) – 该节点的调用目标。有关语义的详细信息,请参阅 Node
  • args (Tuple) – 该调用的位置参数元组
  • kwargs (Dict) – 该调用的关键字参数字典

返回类型:任意

返回值:方法调用返回的值

注意:此 API 保证向后兼容性。


call_module(target, args, kwargs)

执行一个call_module节点并返回结果。

参数

  • target (Target) – 该节点的调用目标。关于语义的详细信息请参阅
    Node
  • args (Tuple) – 本次调用的位置参数元组
  • kwargs (Dict) – 本次调用的关键字参数字典

返回类型:Any

返回值:模块调用返回的值

注意:此API保证向后兼容性。


fetch_args_kwargs_from_env(n)

从当前执行环境中获取节点nargskwargs具体值

参数

  • n (Node) – 需要获取argskwargs的目标节点

返回值
节点n对应的具体argskwargs

返回类型:Tuple[Tuple, Dict]

注意:本API保证向后兼容性


fetch_attr(target)

self.moduleModule 层级结构中获取一个属性。

参数

  • target (str) - 要获取属性的全限定名称

返回

该属性的值。

返回类型

任意类型

注意:此 API 保证向后兼容。


get_attr(target, args, kwargs)

执行一个 get_attr 节点。该操作会从 self.moduleModule 层级结构中获取属性值。

参数

  • target (Target) – 该节点的调用目标。关于语义的详细信息请参阅 Node
  • args (Tuple) – 本次调用的位置参数元组
  • kwargs (Dict) – 本次调用的关键字参数字典

返回值
获取到的属性值

返回类型
任意类型

注意:此 API 保证向后兼容性。


map_nodes_to_values(args, n)

递归遍历 args 并在当前执行环境中查找每个 Node 的具体值。

参数

  • args (Argument) – 需要查找具体值的数据结构
  • n (Node)args 所属的节点。仅用于错误报告。

返回类型:Optional[Union [tuple [Argument’, …], Sequence [Argument], Mapping [str , Argument], slice , range , Node, str , int , float, bool , complex , dtype, Tensor , device , memory_format , layout , OpOverload, SymInt, SymBool , SymFloat ]]

注意:此 API 保证向后兼容性。


output(target, args, kwargs)

执行一个output节点。该操作实际上只是获取output节点引用的值并返回它。

参数

  • target (Target) – 该节点的调用目标。有关语义详情请参阅
    Node
  • args (Tuple) – 本次调用的位置参数元组
  • kwargs (Dict) – 本次调用的关键字参数字典

返回值:输出节点引用的返回值

返回类型:任意类型

注意:此API保证向后兼容。


placeholder(target, args, kwargs)

执行一个placeholder节点。请注意这是有状态的:

Interpreter内部维护了一个针对run方法传入参数的迭代器,本方法会返回该迭代器的next()结果。

参数

  • target (Target) – 该节点的调用目标。关于语义的详细信息请参阅Node
  • args (Tuple) – 本次调用的位置参数元组
  • kwargs (Dict) – 本次调用的关键字参数字典

返回值:获取到的参数值。

返回类型:任意类型

注意:此API保证向后兼容。


run(*args, initial_env=None, enable_io_processing=True)

通过解释执行模块并返回结果。

参数

  • *args – 按位置顺序传递给模块的运行参数
  • initial_env (Optional[Dict[Node, Any]]) – 可选的执行初始环境。这是一个将节点映射到任意值的字典。例如,可用于预先填充某些节点的结果,从而在解释器中仅进行部分求值。
  • enable_io_processing ([bool]) – 如果为true,我们会在使用输入和输出之前,先用图的process_inputs和process_outputs函数对它们进行处理。

返回值:执行模块后返回的值

返回类型:任意

注意:此API保证向后兼容。


run_node(n)

运行特定节点 n 并返回结果。

根据 node.op 的类型,调用对应的占位符、get_attr、call_function、call_method、call_module 或 output 方法。

参数

  • n (Node) – 需要执行的节点

返回值:执行节点 n 的结果

返回类型:任意类型

注意:此 API 保证向后兼容性。


class torch.fx.Transformer(module)

Transformer 是一种特殊类型的解释器,用于生成新的 Module。它提供了一个 transform() 方法,返回转换后的 Module。与 Interpreter 不同,Transformer 不需要参数即可运行,完全基于符号化方式工作。


示例

假设我们需要将所有 torch.neg 实例与 torch.sigmoid 互换(包括它们的 Tensor 方法等效形式)。可以通过如下方式子类化 Transformer

class NegSigmSwapXformer(Transformer):def call_function(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:if target == torch.sigmoid:return torch.neg(*args, **kwargs)return super().call_function(target, args, kwargs)def call_method(self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:if target == "neg":call_self, *args_tail = argsreturn call_self.sigmoid(*args_tail, **kwargs)return super().call_method(target, args, kwargs)def fn(x):return torch.sigmoid(x).neg()gm = torch.fx.symbolic_trace(fn)transformed: torch.nn.Module = NegSigmSwapXformer(gm).transform()
input = torch.randn(3, 4)
torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())

参数

  • module ([GraphModule](https://pytorch.org/docs/stable/data.html#torch.fx.GraphModule "torch.fx.GraphModule")) – 待转换的Module对象。

注意:此API保证向后兼容性。


call_function(target, args, kwargs)

注意:该 API 保证向后兼容。

返回类型

Any


call_module(target, args, kwargs)

注意:此 API 保证向后兼容。

返回类型

Any


get_attr(target, args, kwargs)

执行一个 get_attr 节点。在 Transformer 中,该方法被重写以便向输出图中插入新的 get_attr 节点。

参数

  • target (Target) – 该节点的调用目标。关于语义的详细信息请参阅
    Node
  • args (Tuple) – 该调用的位置参数元组
  • kwargs (Dict) – 该调用的关键字参数字典

返回类型
Proxy

注意:此 API 保证向后兼容。


placeholder(target, args, kwargs)

执行一个 placeholder 节点。在 Transformer 中,该方法被重写以便向输出图中插入新的 placeholder

参数

  • target (Target) – 该节点的调用目标。关于语义的详细信息请参阅 Node
  • args (Tuple) – 该调用的位置参数元组
  • kwargs (Dict) – 该调用的关键字参数字典

返回类型:Proxy

注意:此 API 保证向后兼容。


transform()

转换 self.module 并返回转换后的 GraphModule

注意:此 API 保证向后兼容性。

返回类型 : GraphModule


torch.fx.replace_pattern(gm, pattern, replacement)

在GraphModule的图结构(gm)中,匹配所有可能的非重叠运算符集及其数据依赖关系(pattern),然后将每个匹配到的子图替换为另一个子图(replacement)。

参数

  • gm (GraphModule) - 封装待操作图的GraphModule
  • pattern (Union[Callable, GraphModule]) - 需要在gm中匹配并替换的子图
  • replacement (Union[Callable, GraphModule]) - 用于替换pattern的子图

返回值:返回一个Match对象列表,表示原始图中与pattern匹配的位置。如果没有匹配项则返回空列表。Match定义如下:

class Match(NamedTuple):# Node from which the match was foundanchor: Node# Maps nodes in the pattern subgraph to nodes in the larger graphnodes_map: Dict[Node, Node]

返回类型:List[Match]


示例:

import torch
from torch.fx import symbolic_trace, subgraph_rewriterclass M(torch.nn.Module):def __init__(self) -None:super().__init__()def forward(self, x, w1, w2):m1 = torch.cat([w1, w2]).sum()m2 = torch.cat([w1, w2]).sum()return x + torch.max(m1) + torch.max(m2)def pattern(w1, w2):return torch.cat([w1, w2])def replacement(w1, w2):return torch.stack([w1, w2])traced_module = symbolic_trace(M())subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

上述代码会先在 traced_moduleforward 方法中匹配 pattern。模式匹配基于使用-定义关系而非节点名称进行。例如,若 pattern 中包含 p = torch.cat([a, b]),则可以在原始 forward 函数中匹配到 m = torch.cat([a, b]),即使变量名不同(pm)也不影响。

pattern 中的 return 语句仅根据其值进行匹配,它可能与更大图中的 return 语句匹配,也可能不匹配。换句话说,模式不必延伸至更大图的末尾。

当模式匹配成功时,它将从更大的函数中被移除,并由 replacement 替换。如果更大函数中存在多个 pattern 匹配项,每个非重叠的匹配项都会被替换。若出现匹配重叠的情况,则替换重叠匹配集中最先找到的匹配项(此处的"最先"定义为节点使用-定义关系拓扑排序中的第一个节点。大多数情况下,第一个节点是紧接 self 后出现的参数,而最后一个节点是函数返回的内容)。

需要特别注意:pattern 可调用对象的参数必须在该可调用对象内部使用,且 replacement 可调用对象的参数必须与模式匹配。第一条规则解释了为何上述代码块中 forward 函数有参数 x, w1, w2,而 pattern 函数只有参数 w1, w2——因为 pattern 未使用 x,故不应将 x 指定为参数。

关于第二条规则的示例,考虑替换…


def pattern(x, y):return torch.neg(x) + torch.relu(y)

with


def replacement(x, y):return torch.relu(x)

在这种情况下,replacement需要与pattern相同数量的参数(包括xy),即使参数yreplacement中并未使用。

调用subgraph_rewriter.replace_pattern后,生成的Python代码如下所示:

def forward(self, x, w1, w2):stack_1 = torch.stack([w1, w2])sum_1 = stack_1.sum()stack_2 = torch.stack([w1, w2])sum_2 = stack_2.sum()max_1 = torch.max(sum_1)add_1 = x + max_1max_2 = torch.max(sum_2)add_2 = add_1 + max_2return add_2

注意:该 API 保证向后兼容。



torch.fx.experimental


警告:这些API属于实验性质,可能会随时变更而不另行通知。


torch.fx.experimental.symbolic_shapes

ShapeEnv
DimDynamic控制如何为维度分配符号。
StrictMinMaxConstraint对客户端:该维度的大小必须在’vr’范围内(指定包含性上下界),且必须为非负数且不应为0或1(但参见下方注意事项)。
RelaxedUnspecConstraint对客户端:无显式约束;约束由追踪过程中的守卫隐式推断得出。
EqualityConstraint表示并判定输入源之间的各类相等性约束。
SymbolicContext数据结构,指定在create_symbolic_sizes_strides_storage_offset中如何创建符号;例如,应为静态还是动态。
StatelessSymbolicContext通过DimDynamicDimConstraint给定的symbolic_context判定,在create_symbolic_sizes_strides_storage_offset中创建符号。
StatefulSymbolicContext通过Source:Symbol缓存给定的symbolic_context判定,在create_symbolic_sizes_strides_storage_offset中创建符号。
SubclassSymbolicContext可追踪张量子类的内部张量的正确符号上下文可能与外部符号上下文不同。
DimConstraints针对符号维度约束系统的自定义求解器。
ShapeEnvSettings封装所有可能影响FakeTensor调度的形状环境设置。
ConvertIntKey
CallMethodKey
PropagateUnbackedSymInts
DivideByKey
InnerTensorKey
hint_int获取整数的提示值(基于运行时观察到的底层实际值)。
is_concrete_int检查SymInt底层对象是否为具体值的实用工具。
is_concrete_bool检查SymBool底层对象是否为具体值的实用工具。
is_concrete_float检查SymInt底层对象是否为具体值的实用工具。
has_free_symbolsbool(free_symbols(val))的快速版本
has_free_unbacked_symbolsbool(free_unbacked_symbols(val))的快速版本
definitely_true仅当能确定a为True时返回True,过程中可能引入守卫。
definitely_false仅当能确定a为False时返回True,过程中可能引入守卫。
guard_size_oblivious以无视大小的方式对符号布尔表达式执行守卫。
sym_eq类似==,但在列表/元组上运行时,会递归测试相等性并使用sym_and连接结果,不引入守卫。
constrain_range应用约束使传入的SymInt必须在min-max范围内(包含边界),且不引入SymInt的守卫(意味着可用于未绑定的SymInt)。
constrain_unify给定两个SymInt,约束它们必须相等。
canonicalize_bool_expr通过将布尔表达式转换为lt/le不等式并将所有非常量项移至右侧,实现规范化。
statically_known_true如果x可简化为常量且为真,则返回True。
lru_cache
check_consistent测试两个"meta"值(通常为Tensor或SymInt)是否具有相同的值,例如在重追踪后。
compute_unbacked_bindings在运行fake tensor传播并生成example_value结果后,遍历example_value查找新绑定的未支持符号并记录其路径供后续使用。
rebind_unbacked假设我们正在重追踪一个已有FX图,该图先前进行过fake tensor传播(因此存在未支持的SymInt)。
resolve_unbacked_bindings
is_accessor_node

torch.fx.experimental.proxy_tensor

make_fx给定函数f,返回一个新函数。当使用有效参数执行该函数时,会返回一个FX GraphModule,表示执行过程中所执行的操作集合。
handle_sym_dispatch调用当前活动的代理跟踪模式,对操作SymInt/SymFloat/SymBool参数的函数进行符号调度跟踪。
get_proxy_mode获取当前活动的代理跟踪模式,如果当前未处于跟踪状态则返回None。
maybe_enable_thunkify在此上下文管理器内,如果正在进行make_fx跟踪,将对所有SymNode计算进行thunkify处理,并避免将其跟踪到图中,除非确实需要。
maybe_disable_thunkify在某个上下文中禁用thunkification功能。


torch.hub

PyTorch Hub 是一个预训练模型仓库,旨在促进研究可复现性。


发布模型

PyTorch Hub 支持通过添加简单的 hubconf.py 文件,将预训练模型(模型定义和预训练权重)发布到 GitHub 仓库。

hubconf.py 可以包含多个入口点。每个入口点都定义为 Python 函数(例如:您想发布的预训练模型)。


def entrypoint_name(args, *kwargs):# args & kwargs are optional, for models which take positional/keyword arguments....

如何实现入口点?

以下代码片段展示了如果我们扩展 pytorch/vision/hubconf.py 中的实现,如何为 resnet18 模型指定入口点。在大多数情况下,只需在 hubconf.py 中导入正确的函数就足够了。这里我们使用扩展版本作为示例来说明其工作原理。

完整脚本可查看 pytorch/vision 代码库

dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18# resnet18 is the name of entrypoint
def resnet18(pretrained=False, *kwargs):""" # This docstring shows up in hub.help()Resnet18 modelpretrained (bool): kwargs, load pretrained weights into the model"""# Call the model, load pretrained weightsmodel = _resnet18(pretrained=pretrained, *kwargs)return model

  • dependencies 变量是一个列表,包含加载模型所需的包名。注意这里可能与训练模型所需的依赖项略有不同。
  • argskwargs 会传递给实际的可调用函数。
  • 函数的文档字符串(docstring)将作为帮助信息。它需要说明模型的功能以及允许的位置参数/关键字参数。强烈建议在此处添加几个示例。
  • 入口函数可以返回一个模型(nn.Module),也可以返回辅助工具(如分词器)来优化用户工作流程。
  • 以下划线开头的可调用对象被视为辅助函数,不会出现在 torch.hub.list() 的返回结果中。
  • 预训练权重可以存储在GitHub仓库本地,也可以通过 torch.hub.load_state_dict_from_url() 加载。如果小于2GB,建议将其附加到项目发布版并使用发布版的URL。

在上面的示例中,torchvision.models.resnet.resnet18 处理了 pretrained 参数,你也可以将以下逻辑放在入口函数定义中。


    # For checkpoint saved in local GitHub repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pthdirname = os.path.dirname(__file__)checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)state_dict = torch.load(checkpoint)model.load_state_dict(state_dict)# For checkpoint saved elsewherecheckpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))

重要通知

  • 发布的模型至少应位于分支/标签中,不能是随机提交。

从Hub加载模型

PyTorch Hub 提供了一系列便捷的API,帮助开发者探索Hub中所有可用模型:

  • 通过 torch.hub.list() 查看所有模型
  • 使用 torch.hub.help() 显示文档说明和示例
  • 调用 torch.hub.load() 加载预训练模型

torch.hub.list(github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True)

列出由 github 指定的代码仓库中所有可调用的入口点。

参数

  • github (str) – 格式为“repo_owner/repo_name[:ref]”的字符串,其中 ref(标签或分支)为可选。如果未指定 ref,则默认分支为 main(如果存在),否则为 master
    示例:‘pytorch/vision:0.10’
  • force_reload ([bool], 可选) – 是否丢弃现有缓存并强制重新下载。默认为 False
  • skip_validation ([bool], 可选) – 如果为 False,torchhub 会检查 github 参数指定的分支或提交是否确实属于该仓库所有者。此操作会向 GitHub API 发起请求;可通过设置 GITHUB_TOKEN 环境变量指定非默认的 GitHub 令牌。默认为 False
  • trust_repo ([bool],* str 或 *None)
    "check"TrueFalseNone
    此参数在 v1.12 版本引入,用于确保用户仅运行信任的仓库代码。
  • 如果为 False,会提示用户确认是否信任该仓库。
  • 如果为 True,仓库将被添加到信任列表并直接加载,无需明确确认。
  • 如果为 "check",会检查该仓库是否在缓存的信任列表中。若不存在,则回退到 trust_repo=False 的行为。
  • 如果为 None:会发出警告,提示用户将 trust_repo 设为 FalseTrue"check"。此选项仅为向后兼容保留,将在 v2.0 版本移除。默认为 None,未来 v2.0 版本将改为默认 "check"
  • verbose ([bool], 可选) – 如果为 False,则屏蔽关于命中本地缓存的消息。注意首次下载的消息无法屏蔽。默认为 True

返回
可用的可调用入口点列表

返回类型
list


示例

>>> entrypoints = torch.hub.list("pytorch/vision", force_reload=True)

torch.hub.help(github, model, force_reload=False, skip_validation=False, trust_repo=None)

显示入口点 model 的文档字符串。

参数

  • github (str) – 格式为 <repo_owner/repo_name[:ref]> 的字符串,其中 ref(标签或分支)是可选的。如果未指定 ref,则默认分支为 main(如果存在),否则为 master

示例:‘pytorch/vision:0.10’

  • model (str) – 仓库 hubconf.py 中定义的入口点名称字符串
  • force_reload ([bool], 可选) – 是否丢弃现有缓存并强制重新下载。默认为 False
  • skip_validation ([bool], 可选) – 如果为 False,torchhub 将检查 github 参数指定的 ref 是否确实属于该仓库所有者。这将向 GitHub API 发出请求;您可以通过设置 GITHUB_TOKEN 环境变量来指定非默认的 GitHub 令牌。默认为 False
  • trust_repo ([bool], *str 或 *None)

"check"TrueFalseNone

此参数在 v1.12 版本引入,用于确保用户仅运行来自受信任仓库的代码。

  • 如果为 False,将提示用户确认是否信任该仓库。

  • 如果为 True,该仓库将被添加到受信任列表并直接加载,无需明确确认。

  • 如果为 "check",将检查该仓库是否在缓存的受信任仓库列表中。如果不在列表中,行为将回退到 trust_repo=False 选项。

  • 如果为 None:将发出警告,提示用户将 trust_repo 设置为 FalseTrue"check"。此选项仅用于向后兼容,将在 v2.0 版本移除。默认为 None,最终将在 v2.0 版本更改为 "check"


示例

>>> print(torch.hub.help("pytorch/vision", "resnet18", force_reload=True))

torch.hub.load(repo_or_dir, model, *args, source='github', trust_repo=None, force_reload=False, verbose=True, skip_validation=False, **kwargs)

从 GitHub 仓库或本地目录加载模型。

注意:加载模型是典型用例,但此功能也可用于加载其他对象,如分词器、损失函数等。

如果 source 为 ‘github’,则 repo_or_dir 应为 repo_owner/repo_name[:ref] 格式,其中 ref(标签或分支)为可选项。

如果 source 为 ‘local’,则 repo_or_dir 应为本地目录路径。

参数

  • repo_or_dir (str) – 如果 source 为 ‘github’,则应为 GitHub 仓库,格式为 repo_owner/repo_name[:ref](ref 为可选的标签或分支),例如 ‘pytorch/vision:0.10’。如果未指定 ref,则默认分支为 main(如果存在),否则为 master

如果 source 为 ‘local’,则应为本地目录路径。

  • model (str) – 仓库/目录中 hubconf.py 文件定义的可调用对象(入口点)名称。
  • *args (可选) – 可调用对象 model 的对应参数。
  • source (str, 可选) – ‘github’ 或 ‘local’。指定如何解释 repo_or_dir。默认为 ‘github’。
  • trust_repo ([bool],* str 或 *None)

"check"TrueFalseNone

此参数在 v1.12 中引入,用于确保用户仅运行信任仓库中的代码。

  • 如果为 False,将提示用户确认是否信任该仓库。

  • 如果为 True,该仓库将被添加到信任列表并直接加载,无需明确确认。

  • 如果为 "check",将检查该仓库是否在缓存信任列表中。如果不在,则回退到 trust_repo=False 的行为。

  • 如果为 None:将发出警告,提示用户将 trust_repo 设置为 FalseTrue"check"。此选项仅用于向后兼容,将在 v2.0 中移除。默认为 None,最终将在 v2.0 中改为 "check"

  • force_reload ([bool], 可选) – 是否无条件强制重新下载 GitHub 仓库。如果 source = 'local' 则无效。默认为 False
  • verbose ([bool], 可选) – 如果为 False,则屏蔽有关命中本地缓存的消息。注意首次下载的消息无法屏蔽。如果 source = 'local' 则无效。默认为 True
  • skip_validation ([bool], 可选) – 如果为 False,torchhub 将检查 github 参数指定的分支或提交是否属于该仓库所有者。这将向 GitHub API 发出请求;您可通过设置 GITHUB_TOKEN 环境变量指定非默认的 GitHub 令牌。默认为 False
  • **kwargs (可选) – 可调用对象 model 的对应关键字参数。

返回

调用 model 可调用对象时,传入给定 *args**kwargs 的输出。


示例

>>> # from a github repo
>>> repo = "pytorch/vision"
>>> model = torch.hub.load(
...     repo, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1"
... )
>>> # from a local directory
>>> path = "/some/local/path/pytorch/vision"
>>> model = torch.hub.load(path, "resnet50", weights="ResNet50_Weights.DEFAULT")

torch.hub.download_url_to_file(url, dst, hash_prefix=None, progress=True)

将给定URL的对象下载到本地路径。

参数

  • url (str) - 要下载对象的URL
  • dst (str) - 对象将被保存的完整路径,例如/tmp/temporary_file
  • hash_prefix (str, 可选) - 如果不为None,下载文件的SHA256哈希值应以hash_prefix开头。默认值:None
  • progress ([bool], 可选) - 是否在标准错误输出中显示进度条。默认值:True

示例

>>> torch.hub.download_url_to_file(
...     "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth", 
...     "/tmp/temporary_file", 
... )

torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None, weights_only=False)

从给定的URL加载Torch序列化对象。

如果下载的文件是zip压缩包,系统会自动解压。

如果对象已存在于model_dir目录中,则直接反序列化并返回。

model_dir的默认值为<hub_dir>/checkpoints,其中hub_dir是由get_dir()返回的目录路径。

参数说明

  • url (str) - 需要下载对象的URL地址
  • model_dir (str, 可选) - 保存对象的目录路径
  • map_location (可选) - 指定存储位置重映射的函数或字典(参见torch.load)
  • progress ([bool], 可选) - 是否在标准错误输出中显示进度条。默认值:True
  • check_hash ([bool], 可选) - 若为True,则URL中的文件名部分需遵循命名规范:
    filename-<sha256>.ext,其中<sha256>是文件内容SHA256哈希值的前8位或更多位数字。该哈希值用于确保唯一文件名并验证文件内容。默认值:False
  • file_name (str, 可选) - 下载文件的名称。若未设置,则使用URL中的文件名
  • weights_only ([bool], 可选) - 若为True,则仅加载权重而不加载复杂的pickle对象。建议用于不可信来源。详见load()说明

返回类型:dict[str, Any]

使用示例


>>> state_dict = torch.hub.load_state_dict_from_url(
...     "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth"
... )

运行加载的模型:

请注意,torch.hub.load() 中的 *args**kwargs 用于实例化模型。加载模型后,如何了解该模型的功能?
建议的工作流程如下:

  • 使用 dir(model) 查看模型所有可用的方法
  • 通过 help(model.foo) 查看 model.foo 运行所需的参数

为了帮助用户无需反复查阅文档即可探索功能,我们强烈建议仓库维护者确保函数帮助信息清晰简洁。同时,提供一个最小可运行示例也非常有帮助。


下载的模型保存在哪里?

模型保存路径按以下顺序确定:
1、调用 hub.set_dir(<PATH_TO_HUB_DIR>) 设置的路径
2、若设置了环境变量 TORCH_HOME,则使用 $TORCH_HOME/hub
3、若设置了环境变量 XDG_CACHE_HOME,则使用 $XDG_CACHE_HOME/torch/hub
4、默认路径为 ~/.cache/torch/hub

可通过 torch.hub.get_dir() 获取当前保存路径


***
Get the Torch Hub cache directory used for storing downloaded models \& weights.If [`set_dir()`](https://pytorch.org/docs/stable/data.html#torch.hub.set_dir "torch.hub.set_dir") is not called, default path is `$TORCH_HOME/hub` where
environment variable `$TORCH_HOME` defaults to `$XDG_CACHE_HOME/torch`.
`$XDG_CACHE_HOME` follows the X Design Group specification of the Linux
filesystem layout, with a default value `~/.cache` if the environment
variable is not set.Return typestr torch.hub.set_dir(d)

可选设置用于保存下载模型和权重的 Torch Hub 目录。

参数

  • d (str) – 用于保存下载模型和权重的本地文件夹路径。

缓存逻辑

默认情况下,我们在加载文件后不会进行清理。如果文件已存在于 get_dir() 返回的目录中,Hub 会默认使用缓存。

用户可以通过调用 hub.load(..., force_reload=True) 强制重新加载。这将删除现有的 GitHub 文件夹和下载的权重文件,并重新初始化全新下载。当同一分支发布更新时,这个功能非常有用,用户可以及时获取最新版本。


已知限制:

Torch hub 的工作原理是将包当作已安装的包进行导入。Python 导入机制会带来一些副作用,例如你可能会在 Python 缓存 sys.modulessys.path_importer_cache 中看到新增条目,这是 Python 的正常行为。这也意味着,如果不同代码仓库包含同名的子包(通常是名为 model 的子包),在从不同仓库导入不同模型时可能会遇到导入错误。针对这类导入错误的解决方案是从 sys.modules 字典中移除冲突的子包,更多细节可参考 这个 GitHub issue。

需要特别说明的一个已知限制:用户无法同一个 Python 进程中加载同一代码仓库的两个不同分支。这就像在 Python 中安装两个同名的包一样,是不合理的做法。如果强行尝试,缓存机制可能会介入并带来意外结果。当然,在独立的进程中分别加载它们是完全可行的。



TorchScript

TorchScript 是一种将 PyTorch 代码转换为可序列化和可优化模型的方法。任何 TorchScript 程序都可以从 Python 进程中保存,并在不依赖 Python 的环境中加载运行。

我们提供了一系列工具,帮助开发者逐步将纯 Python 模型转换为可独立于 Python 运行的 TorchScript 程序,例如在独立的 C++ 程序中运行。这使得开发者能够继续使用熟悉的 Python 工具训练 PyTorch 模型,然后通过 TorchScript 将模型导出到生产环境。在生产环境中,由于性能和多线程方面的考虑,使用 Python 程序可能并不合适。

如需了解 TorchScript 的入门指南,请参阅 TorchScript 简介教程。

若想查看将 PyTorch 模型转换为 TorchScript 并在 C++ 中运行的完整示例,请参考 在 C++ 中加载 PyTorch 模型 教程。


创建 TorchScript 代码

script将函数转换为脚本
trace追踪函数并返回一个可执行对象或ScriptFunction,该对象将通过即时编译进行优化
script_if_tracing在追踪过程中首次调用时编译fn
trace_module追踪模块并返回一个可执行的ScriptModule,该模块将通过即时编译进行优化
fork创建异步任务执行函数,并返回对该执行结果的引用
wait强制完成torch.jit.Future[T]异步任务,返回任务结果
ScriptModuleC++ torch::jit::Module的包装器,包含方法、属性和参数
ScriptFunction功能上等同于ScriptModule,但表示单个函数且不包含任何属性或参数
freeze冻结ScriptModule,将子模块和属性内联为常量
optimize_for_inference执行一系列优化步骤,为推理目的优化模型
enable_onednn_fusion根据参数enabled启用或禁用onednn JIT融合
onednn_fusion_enabled返回onednn JIT融合是否启用
set_fusion_strategy设置融合过程中可能发生的特化类型和数量
strict_fusion如果推理中未融合所有节点或训练中未符号微分,则报错
save保存此模块的离线版本以供其他进程使用
load加载先前用torch.jit.save保存的ScriptModuleScriptFunction
ignore此装饰器向编译器表明应忽略函数或方法,保留为Python函数
unused此装饰器向编译器表明应忽略函数或方法,并替换为抛出异常
interface用于注解不同类型类或模块的装饰器
isinstance在TorchScript中提供容器类型细化
Attribute此方法是一个返回值的直通函数,主要用于向TorchScript编译器表明左侧表达式是具有type类型的类实例属性
annotate用于在TorchScript编译器中指定the_value的类型

混合使用追踪与脚本化

在多数情况下,追踪(tracing)或脚本化(scripting)都是将模型转换为 TorchScript 的更简便方式。根据模型不同部分的具体需求,可以组合使用这两种方法。

脚本化函数能够调用追踪生成的函数。当需要在简单前馈模型周围添加控制流逻辑时,这种方式特别有用。例如,序列到序列模型中的束搜索(beam search)通常会用脚本编写,但可以调用通过追踪生成的编码器模块。


示例(在脚本中调用追踪函数):

import torchdef foo(x, y):return 2 * x + ytraced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))@torch.jit.script
def bar(x):return traced_foo(x, x)

被追踪的函数可以调用脚本函数。当模型的大部分只是前馈网络,而其中一小部分需要控制流时,这非常有用。在被追踪函数调用的脚本函数内部,控制流会被正确保留。

示例(在被追踪函数中调用脚本函数):

import torch@torch.jit.script
def foo(x, y):if x.max() y.max():r = xelse:r = yreturn rdef bar(x, y, z):return foo(x, y) + ztraced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3)))

该组合方式同样适用于 nn.Module,它可以通过追踪生成一个子模块,该子模块可从脚本模块的方法中调用。

示例(使用追踪模块):

import torch
import torchvisionclass MyScriptModule(torch.nn.Module):def __init__(self):super().__init__()self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68]).resize_(1, 3, 1, 1))self.resnet = torch.jit.trace(torchvision.models.resnet18(),   torch.rand(1, 3, 224, 224))def forward(self, input):return self.resnet(input - self.means)my_script_module = torch.jit.script(MyScriptModule())

TorchScript 语言

TorchScript 是 Python 的一个静态类型子集,因此许多 Python 特性可以直接应用于 TorchScript。详情请参阅完整的 TorchScript 语言参考。


内置函数与模块

TorchScript 支持使用大多数 PyTorch 函数和许多 Python 内置功能。完整支持函数列表请参阅 TorchScript 内置函数。


PyTorch 函数与模块

TorchScript 支持 PyTorch 提供的张量和神经网络函数子集。Tensor 上的大多数方法、torch 命名空间中的所有函数、torch.nn.functional 中的全部函数以及 torch.nn 中的大多数模块均可被 TorchScript 支持。

不支持的 PyTorch 函数和模块列表请参阅 TorchScript 不支持的 PyTorch 结构。


Python 函数与模块

TorchScript 支持许多 Python 的内置函数。
math 模块同样受支持(详见数学模块),但其他 Python 模块(无论是内置还是第三方)均不支持。


Python 语言参考对比

如需查看支持的 Python 功能完整列表,请参阅 Python 语言参考覆盖范围。


调试


禁用 JIT 进行调试

PYTORCH_JIT

设置环境变量 PYTORCH_JIT=0 将禁用所有脚本和追踪注解。当您的 TorchScript 模型中出现难以调试的错误时,可以通过此标志强制所有代码以原生 Python 方式运行。由于该标志会禁用 TorchScript(脚本化和追踪),您可以使用诸如 pdb 之类的工具来调试模型代码。例如:

@torch.jit.script
def scripted_fn(x : torch.Tensor):for i in range(12):x = x + xreturn xdef fn(x):x = torch.neg(x)import pdb; pdb.set_trace()return scripted_fn(x)traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),))
traced_fn(torch.rand(3, 4))

使用 pdb 调试此脚本时一切正常,但当我们调用 @torch.jit.script 函数时会失效。我们可以全局禁用 JIT 功能,这样就能将 @torch.jit.script 作为普通 Python 函数调用而不进行编译。如果上述脚本名为 disable_jit_example.py,可以通过以下方式调用:

$ PYTORCH_JIT=0 python disable_jit_example.py

这样我们就能像普通 Python 函数一样单步调试 @torch.jit.script 装饰的函数。如需禁用特定函数的 TorchScript 编译器,请参阅 @torch.jit.ignore


代码检查

TorchScript 为所有 ScriptModule 实例提供了代码美化打印器。该美化打印器能够将脚本方法的代码以有效的 Python 语法形式呈现。例如:

@torch.jit.script
def foo(len):# type: (int) -torch.Tensorrv = torch.zeros(3, 4)for i in range(len):if i < 10:rv = rv - 1.0else:rv = rv + 1.0return rvprint(foo.code)

一个包含单个 forward 方法的 ScriptModule 会有一个 code 属性,你可以通过该属性来检查 ScriptModule 的代码。

如果 ScriptModule 包含多个方法,你需要访问方法本身的 .code 属性,而不是模块的。我们可以通过访问 .foo.code 来检查 ScriptModule 上名为 foo 的方法代码。

上面的示例会产生以下输出:

def foo(len: int) -Tensor:rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)rv0 = rvfor i in range(len):if torch.lt(i, 10):rv1 = torch.sub(rv0, 1., 1)else:rv1 = torch.add(rv0, 1., 1)rv0 = rv1return rv0

这是 TorchScript 对 forward 方法代码的编译结果。

您可以通过它来验证 TorchScript(无论是通过追踪还是脚本化方式)是否正确捕获了您的模型代码。


解读图结构

TorchScript 在代码美化打印器之下还有一个更低层次的表示形式,即 IR(中间表示)图。

TorchScript 采用静态单赋值(SSA)中间表示(IR)来描述计算过程。这种格式的指令由 ATen(PyTorch 的 C++ 后端)运算符和其他基础运算符组成,包括用于循环和条件判断的控制流运算符。例如:

@torch.jit.script
def foo(len):# type: (int) -torch.Tensorrv = torch.zeros(3, 4)for i in range(len):if i < 10:rv = rv - 1.0else:rv = rv + 1.0return rvprint(foo.graph)

graph 遵循与代码检查章节中描述的相同规则,涉及 forward 方法查找。

上面的示例脚本生成如下图表:

graph(%len.1 : int):%24 : int = prim::Constant[value=1]()%17 : bool = prim::Constant[value=1]() # test.py:10:5%12 : bool? = prim::Constant()%10 : Device? = prim::Constant()%6 : int? = prim::Constant()%1 : int = prim::Constant[value=3]() # test.py:9:22%2 : int = prim::Constant[value=4]() # test.py:9:25%20 : int = prim::Constant[value=10]() # test.py:11:16%23 : float = prim::Constant[value=1]() # test.py:12:23%4 : int[] = prim::ListConstruct(%1, %2)%rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10%rv : Tensor = prim::Loop(%len.1, %17, %rv.1) # test.py:10:5block0(%i.1 : int, %rv.14 : Tensor):%21 : bool = aten::lt(%i.1, %20) # test.py:11:12%rv.13 : Tensor = prim::If(%21) # test.py:11:9block0():%rv.3 : Tensor = aten::sub(%rv.14, %23, %24) # test.py:12:18-(%rv.3)block1():%rv.6 : Tensor = aten::add(%rv.14, %23, %24) # test.py:14:18-(%rv.6)-(%17, %rv.13)return (%rv)

以指令%rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10为例:

  • %rv.1 : Tensor表示我们将输出赋值给一个名为rv.1的(唯一)值,该值的类型为Tensor,且我们不知道其具体形状。
  • aten::zeros是运算符(相当于torch.zeros),输入列表(%4, %6, %6, %10, %12)指定了应传入哪些作用域中的值作为输入。像aten::zeros这样的内置函数的模式可以在内置函数中找到。
  • # test.py:9:10是生成此指令的原始源文件中的位置。在本例中,它位于名为test.py的文件中,第9行,第10个字符。

注意,运算符也可以关联blocks,即prim::Loopprim::If运算符。在图形打印输出中,这些运算符的格式会反映其等效的源代码形式,以便于调试。

可以按照所示方式检查图形,以确认由ScriptModule描述的计算是否正确,无论是自动还是手动方式,如下所述。


追踪器


追踪边界情况

在某些特殊情况下,对给定Python函数/模块的追踪可能无法准确反映底层代码的真实行为。这些情况包括:

  • 依赖于输入的控制流追踪(例如张量形状)
  • 张量视图就地操作的追踪(例如赋值语句左侧的索引操作)

请注意,这些情况未来实际上可能会变得可追踪。


自动追踪检查

自动捕获追踪中多种错误的一种方法是使用 torch.jit.trace() API 上的 check_inputs 参数。check_inputs 接收一个由输入元组组成的列表,这些输入将用于重新追踪计算并验证结果。例如:

def loop_in_traced_fn(x):result = x[0]for i in range(x.size(0)):result = result * x[i]return resultinputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs)

提供以下诊断信息:

ERROR: Graphs differed across invocations!
Graph diff:graph(%x : Tensor) {%1 : int = prim::Constant[value=0]()%2 : int = prim::Constant[value=0]()%result.1 : Tensor = aten::select(%x, %1, %2)%4 : int = prim::Constant[value=0]()%5 : int = prim::Constant[value=0]()%6 : Tensor = aten::select(%x, %4, %5)%result.2 : Tensor = aten::mul(%result.1, %6)%8 : int = prim::Constant[value=0]()%9 : int = prim::Constant[value=1]()%10 : Tensor = aten::select(%x, %8, %9)-   %result : Tensor = aten::mul(%result.2, %10)+   %result.3 : Tensor = aten::mul(%result.2, %10)?          ++%12 : int = prim::Constant[value=0]()%13 : int = prim::Constant[value=2]()%14 : Tensor = aten::select(%x, %12, %13)+   %result : Tensor = aten::mul(%result.3, %14)+   %16 : int = prim::Constant[value=0]()+   %17 : int = prim::Constant[value=3]()+   %18 : Tensor = aten::select(%x, %16, %17)-   %15 : Tensor = aten::mul(%result, %14)?     ^                                 ^+   %19 : Tensor = aten::mul(%result, %18)?     ^                                 ^-   return (%15);?             ^+   return (%19);?             ^}

这条消息表明,计算过程在我们首次追踪时和使用 check_inputs 进行追踪时出现了差异。实际上,loop_in_traced_fn 函数体中的循环依赖于输入 x 的形状,因此当我们尝试使用不同形状的另一个 x 时,追踪结果就会发生变化。

对于这种情况,可以使用 torch.jit.script() 来捕获此类数据依赖的控制流:

def fn(x):result = x[0]for i in range(x.size(0)):result = result * x[i]return resultinputs = (torch.rand(3, 4, 5),)
check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)]scripted_fn = torch.jit.script(fn)
print(scripted_fn.graph)
#print(str(scripted_fn.graph).strip())
for input_tuple in [inputs] + check_inputs:torch.testing.assert_close(fn(input_tuple), scripted_fn(input_tuple))

输出结果为:

graph(%x : Tensor) {%5 : bool = prim::Constant[value=1]()%1 : int = prim::Constant[value=0]()%result.1 : Tensor = aten::select(%x, %1, %1)%4 : int = aten::size(%x, %1)%result : Tensor = prim::Loop(%4, %5, %result.1)block0(%i : int, %7 : Tensor) {%10 : Tensor = aten::select(%x, %1, %i)%result.2 : Tensor = aten::mul(%7, %10)-(%5, %result.2)}return (%result);
}

追踪器警告

追踪器会对追踪计算中的几种问题模式产生警告。例如,假设对一个包含张量切片(视图)进行原地赋值的函数进行追踪:

def fill_row_zero(x):x[0] = torch.rand(x.shape[1:2])return xtraced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)

生成多个警告信息和一个直接返回输入数据的图表


fill_row_zero.py:4: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe.x[0] = torch.rand(x.shape[1:2])
fill_row_zero.py:6: TracerWarning: Output nr 1、of the traced function does not match the corresponding output of the Python function. Detailed error:
Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 1] (0.09115803241729736 vs. 0.6782537698745728) and 3 other locations (33.00%)traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
graph(%0 : Float(3, 4)) {return (%0);
}

我们可以通过修改代码来解决这个问题,不再使用原地更新,而是使用torch.cat来非原地构建结果张量:

def fill_row_zero(x):x = torch.cat((torch.rand(1, x.shape[1:2]), x[1:2]), dim=0)return xtraced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),))
print(traced.graph)

常见问题解答

Q: 我想在GPU上训练模型,然后在CPU上进行推理。有哪些最佳实践?

首先将模型从GPU转换到CPU,然后保存它,如下所示:

cpu_model = gpu_model.cpu()
sample_input_cpu = sample_input_gpu.cpu()
traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu)
torch.jit.save(traced_cpu, "cpu.pt")traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu)
torch.jit.save(traced_gpu, "gpu.pt")# ... later, when using the model:if use_gpu:model = torch.jit.load("gpu.pt")
else:model = torch.jit.load("cpu.pt")model(input)

推荐采用此方式,因为追踪器可能会观测到张量在特定设备上创建的过程,直接转换已加载的模型可能产生意外效果。在保存模型之前进行类型转换,可确保追踪器获取正确的设备信息。

问:如何在ScriptModule上存储属性?

假设我们有如下模型:

import torchclass Model(torch.nn.Module):def __init__(self):super().__init__()self.x = 2def forward(self):return self.xm = torch.jit.script(Model())

如果直接实例化 Model 会导致编译错误,因为编译器无法识别 x。有四种方法可以让编译器识别 ScriptModule 上的属性:

1、nn.Parameter - 用 nn.Parameter 包装的值会像在 nn.Module 中一样正常工作。

2、register_buffer - 用 register_buffer 包装的值会像在 nn.Module 中一样正常工作。这相当于一个类型为 Tensor 的属性(见第4点)。

3、常量 - 将类成员标注为 Final(或在类定义级别将其添加到名为 __constants__ 的列表中)会将包含的名称标记为常量。常量会直接保存在模型的代码中。详情请参阅内置常量。

4、属性 - 支持类型的值可以作为可变属性添加。大多数类型可以自动推断,但有些可能需要明确指定,详情请参阅模块属性。

问题:我想追踪模块的方法,但一直遇到这个错误:

RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient

这个错误通常意味着你正在追踪的方法使用了模块的参数,而你传递的是模块的方法而不是模块实例(例如 my_module_instance.forwardmy_module_instance)。

  • 使用模块的方法调用 trace 会将模块参数(可能需要梯度)捕获为常量
  • 另一方面,使用模块实例(例如 my_module)调用 trace 会创建一个新模块,并正确地将参数复制到新模块中,因此如果需要,它们可以累积梯度。

要追踪模块上的特定方法,请参阅 torch.jit.trace_module


已知问题

当你在 TorchScript 中使用 Sequential 时,某些 Sequential 子模块的输入可能会被错误推断为 Tensor,即使它们被标注为其他类型。标准解决方案是继承 nn.Sequential 并重新声明 forward 方法,确保输入类型正确。


附录


迁移至 PyTorch 1.2 递归脚本化 API

本节详细说明 PyTorch 1.2 中 TorchScript 的变化。如果你是 TorchScript 的新用户,可以跳过这部分内容。PyTorch 1.2 对 TorchScript API 主要做了两处改动:

1、torch.jit.script 现在会尝试递归编译遇到的函数、方法和类。一旦调用 torch.jit.script,编译过程将采用"默认启用"而非"手动启用"机制。

2、torch.jit.script(nn_module_instance) 现已成为创建 ScriptModule 的推荐方式,取代了原先继承 torch.jit.ScriptModule 的做法。

这些改动共同提供了一个更简单易用的 API,用于将你的 nn.Module 转换为可优化并在非 Python 环境中执行的 ScriptModule

新的使用方式如下:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass Model(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 20, 5)self.conv2 = nn.Conv2d(20, 20, 5)def forward(self, x):x = F.relu(self.conv1(x))return F.relu(self.conv2(x))my_model = Model()
my_scripted_model = torch.jit.script(my_model)

  • 模块的 forward 方法默认会被编译。从 forward 中调用的方法会按它们在 forward 中的使用顺序进行惰性编译。
  • 若要编译未被 forward 调用的其他方法,需添加 @torch.jit.export 装饰器。
  • 如需阻止编译器编译某个方法,可添加 @torch.jit.ignore@torch.jit.unused@ignore 会保留 Python 方法调用,而 @unused 会将其替换为异常。@ignored 方法不可导出;@unused 方法可以导出。
  • 大多数属性类型可自动推断,因此无需使用 torch.jit.Attribute。对于空容器类型,建议使用 PEP 526 风格 的类注解来声明类型。
  • 常量可通过 Final 类注解标记,无需将成员名加入 __constants__ 列表。
  • 可用 Python 3 类型提示替代 torch.jit.annotate 函数。

基于这些变更,以下内容已被弃用,新代码中不应继续使用:

  • @torch.jit.script_method 装饰器
  • 继承自 torch.jit.ScriptModule 的类
  • torch.jit.Attribute 包装类
  • __constants__ 数组
  • torch.jit.annotate 函数

模块

警告:在 PyTorch 1.2 中,@torch.jit.ignore 注解的行为发生了变化。在 PyTorch 1.2 之前,@ignore 装饰器用于使函数或方法可以从导出的代码中调用。要恢复此功能,请使用 @torch.jit.unused()。现在 @torch.jit.ignore 等同于 @torch.jit.ignore(drop=False)。详情请参阅 @torch.jit.ignore@torch.jit.unused

当传递给 torch.jit.script 函数时,torch.nn.Module 的数据会被复制到 ScriptModule 中,并由 TorchScript 编译器编译该模块。默认情况下,模块的 forward 方法会被编译。从 forward 调用的方法会按照它们在 forward 中的使用顺序延迟编译,同时也会编译任何带有 @torch.jit.export 注解的方法。


torch.jit.export(fn)

这个装饰器用于标记nn.Module中的某个方法作为ScriptModule的入口点,该方法将被编译。

forward方法默认被视为入口点,因此不需要此装饰器。

forward调用的函数和方法会在编译器处理时自动编译,所以它们也不需要这个装饰器。

示例(在方法上使用@torch.jit.export装饰器):

import torch
import torch.nn as nnclass MyModule(nn.Module):def implicitly_compiled_method(self, x):return x + 99# `forward` is implicitly decorated with `@torch.jit.export`, # so adding it here would have no effectdef forward(self, x):return x + 10@torch.jit.exportdef another_forward(self, x):# When the compiler sees this call, it will compile# `implicitly_compiled_method`return self.implicitly_compiled_method(x)def unused_method(self, x):return x - 20# `m` will contain compiled methods:
#     `forward`
#     `another_forward`
#     `implicitly_compiled_method`
# `unused_method` will not be compiled since it was not called from # any compiled methods and wasn't decorated with `@torch.jit.export`
m = torch.jit.script(MyModule())

函数

函数基本保持不变,必要时可以使用 @torch.jit.ignoretorch.jit.unused 装饰器进行修饰。


# Same behavior as pre-PyTorch 1.2
@torch.jit.script
def some_fn():return 2# Marks a function as ignored, if nothing
# ever calls it then this has no effect
@torch.jit.ignore
def some_fn2():return 2# As with ignore, if nothing calls it then it has no effect.
# If it is called in script it is replaced with an exception.
@torch.jit.unused
def some_fn3():import pdb; pdb.set_trace()return 4# Doesn't do anything, this function is already
# the main entry point
@torch.jit.export
def some_fn4():return 2

TorchScript 类

警告:TorchScript 类的支持目前处于实验阶段。当前最适合用于简单的记录式类型(可理解为附加了方法的NamedTuple)。

用户定义的 TorchScript 类中所有内容默认会被导出,如有需要可以使用 @torch.jit.ignore 装饰器来忽略特定函数。


属性

TorchScript 编译器需要知道模块属性的类型。大多数类型可以通过成员的值推断出来。空列表和字典无法推断其类型,必须使用 PEP 526 风格 的类注解显式标注类型。如果某个类型既无法推断又未显式标注,则不会将其作为属性添加到最终的 ScriptModule 中。

旧版 API:


from typing import Dict
import torchclass MyModule(torch.jit.ScriptModule):def __init__(self):super().__init__()self.my_dict = torch.jit.Attribute({}, Dict[str, int])self.my_int = torch.jit.Attribute(20, int)m = MyModule()

新API:

from typing import Dictclass MyModule(torch.nn.Module):my_dict: Dict[str, int]def __init__(self):super().__init__()# This type cannot be inferred and must be specifiedself.my_dict = {}# The attribute type here is inferred to be `int`self.my_int = 20def forward(self):passm = torch.jit.script(MyModule())

常量

Final 类型构造器可用于将成员标记为常量。如果成员未被标记为常量,它们将被复制到生成的 ScriptModule 中作为属性。使用 Final 可以在已知值固定的情况下开启优化机会,并提供额外的类型安全性。

旧版 API:

class MyModule(torch.jit.ScriptModule):__constants__ = ['my_constant']def __init__(self):super().__init__()self.my_constant = 2def forward(self):pass
m = MyModule()

新 API:

from typing import Finalclass MyModule(torch.nn.Module):my_constant: Final[int]def __init__(self):super().__init__()self.my_constant = 2def forward(self):passm = torch.jit.script(MyModule())

变量

容器默认具有 Tensor 类型且不可为空(更多信息请参阅默认类型章节)。之前使用 torch.jit.annotate 来告知 TorchScript 编译器类型信息,现在已支持 Python 3 风格的类型提示。


import torch
from typing import Dict, Optional@torch.jit.script
def make_dict(flag: bool):x: Dict[str, int] = {}x['hi'] = 2b: Optional[int] = Noneif flag:b = 2return x, b

融合后端

TorchScript 执行优化提供了几种融合后端选择。CPU 上的默认融合器是 NNC,它支持 CPU 和 GPU 的融合操作。而 GPU 上的默认融合器是 NVFuser,它支持更广泛的运算符,并已证明能生成具有更高吞吐量的内核。有关使用和调试的更多细节,请参阅 NVFuser 文档。


参考资料

  • Python 语言参考覆盖范围
  • TorchScript 不支持的 PyTorch 结构


torch.linalg

常用线性代数运算。

有关常见数值边界情况的说明,请参阅线性代数 (torch.linalg)。


矩阵属性

norm计算向量或矩阵范数
vector_norm计算向量范数
matrix_norm计算矩阵范数
diagonaltorch.diagonal()的别名,默认参数为dim1= -2, dim2= -1
det计算方阵的行列式
slogdet计算方阵行列式绝对值的符号和自然对数
cond计算矩阵关于某个矩阵范数的条件数
matrix_rank计算矩阵的数值秩

矩阵分解

cholesky计算复数厄米特矩阵或实数对称正定矩阵的Cholesky分解
qr计算矩阵的QR分解
lu计算带部分主元消去的矩阵LU分解
lu_factor计算带部分主元消去的矩阵LU分解的紧凑表示形式
eig计算方阵的特征值分解(如果存在)
eigvals计算方阵的特征值
eigh计算复数厄米特矩阵或实数对称矩阵的特征值分解
eigvalsh计算复数厄米特矩阵或实数对称矩阵的特征值
svd计算矩阵的奇异值分解(SVD)
svdvals计算矩阵的奇异值

求解器

solve计算具有唯一解的线性方程组的解。
solve_triangular计算具有唯一解的三角线性方程组的解。
lu_solve在给定LU分解的情况下,计算具有唯一解的线性方程组的解。
lstsq计算线性方程组的最小二乘解。

逆矩阵

inv计算方阵的逆矩阵(如果存在)。
pinv计算矩阵的伪逆(Moore-Penrose 逆)。

矩阵函数

matrix_exp计算方阵的矩阵指数
matrix_power计算方阵的整数n次幂

矩阵运算

cross计算两个三维向量的叉积
matmultorch.matmul() 的别名
vecdot沿指定维度计算两批向量的点积
multi_dot通过优化乘法顺序来高效计算两个及以上矩阵的连乘,实现最少算术运算
householder_product计算Householder矩阵乘积的前n列

张量运算

tensorinv计算 torch.tensordot() 的乘法逆元。
tensorsolve计算方程组 torch.tensordot(A, X) = B 的解 X。

杂项函数

vander生成范德蒙矩阵

实验性函数

cholesky_ex计算复厄米特矩阵或实对称正定矩阵的Cholesky分解。
inv_ex计算可逆方阵的逆矩阵。
solve_exsolve() 的一个变体,除非 check_errors= True,否则不执行错误检查。
lu_factor_ex这是 lu_factor() 的一个变体,除非 check_errors= True,否则不执行错误检查。
ldl_factor计算厄米特矩阵或对称矩阵(可能不定)的LDL分解的紧凑表示。
ldl_factor_ex这是 ldl_factor() 的一个变体,除非 check_errors= True,否则不执行错误检查。
ldl_solve使用LDL分解计算线性方程组的解。


torch.monitor


警告:本模块为原型版本,其接口和功能可能在未来的PyTorch版本中未经通知即发生变更。

torch.monitor 提供了从PyTorch记录事件和计数器的接口。

统计接口设计用于追踪高层次指标,这些指标会定期记录以监控系统性能。由于统计数据会按特定窗口大小进行聚合,您可以在关键循环中记录它们而对性能影响极小。

对于不频繁发生的事件或数值(如损失值、准确率、使用情况追踪),可以直接使用事件接口。

可以注册事件处理器来处理事件,并将其传递至外部事件接收器。


API 参考


class torch.monitor.Aggregation 

以下是可用的统计聚合类型:

成员说明:

VALUE :VALUE 返回最后添加的值。

MEAN :MEAN 计算所有添加值的算术平均值。

COUNT :COUNT 返回已添加值的总数量。

SUM :SUM 返回所有添加值的总和。

MAX :MAX 返回添加值中的最大值。

MIN :MIN 返回添加值中的最小值。


property name 

class torch.monitor.Stat 

Stat 用于在固定时间间隔内高效计算汇总统计量。Stat 会每隔 window_size 时长将统计结果记录为一个 Event 事件。当时间窗口关闭时,统计结果会通过事件处理器以 torch.monitor.Stat 事件的形式记录。

建议将 window_size 设置为较高的值(例如 60 秒),以避免记录过多事件。Stat 使用毫秒级精度。

如果设置了 max_samples 参数,Stat 会通过丢弃超出限制的 add 调用来限制每个窗口的最大样本数。未设置该参数时,窗口期内所有的 add 调用都会被纳入统计。这个可选字段主要用于在样本量可能波动的情况下,使不同窗口期的聚合数据更具可比性。

当 Stat 对象被销毁时,即使当前时间窗口尚未结束,它也会记录所有剩余数据。


__init__(self: torch._C._monitor.Stat, name:  str , aggregations:  list [torch._C._monitor.Aggregation], window_size: [datetime.timedelta, max_samples:  int  = 9223372036854775807)None  

构造 Stat 对象。

add(self: torch._C._monitor.Stat, v: float)None 

Adds a value to the stat to be aggregated according to the configured stat type and aggregations.


property count

Number of data points that have currently been collected. Resets
once the event has been logged.


get(self: torch._C._monitor.Stat)dict[torch._C._monitor.Aggregation, float]

Returns the current value of the stat, primarily for testing purposes. If the stat has logged and no additional values have been added this will be zero.


property name

The name of the stat that was set during creation.


class torch.monitor.data_value_t

data_value_t is one of str, float, int, bool.


class torch.monitor.Event

Event represents a specific typed event to be logged. This can represent high-level data points such as loss or accuracy per epoch or more low-level aggregations such as through the Stats provided through this library.

All Events of the same type should have the same name so downstream
handlers can correctly process them.

__init__(self: torch._C._monitor.Event, name: str, timestamp: datetime.datetime, data: dict[str, data_value_t])None

Constructs the Event.


property data

The structured data contained within the Event.


property name

The name of the Event.


property timestamp 

The timestamp when the Event happened.


class torch.monitor.EventHandlerHandle

EventHandlerHandle is a wrapper type returned by register_event_handler used to unregister the handler via unregister_event_handler. This cannot be directly initialized.


torch.monitor.log_event(event: torch._C._monitor.Event)None

log_event logs the specified event to all of the registered event handlers. It’s up to the event handlers to log the event out to the corresponding event sink.

If there are no event handlers registered this method is a no-op.


torch.monitor.register_event_handler(callback: Callable[[torch._C._monitor.Event], None ]) → torch._C._monitor.EventHandlerHandl)

register_event_handler registers a callback to be called whenever an event is logged via log_event. These handlers should avoid blocking the main thread since that may interfere with training as they run during the log_event call.


torch.monitor.unregister_event_handler(handler: torch._C._monitor.EventHandlerHandl))None

unregister_event_handler unregisters the EventHandlerHandle returned after calling register_event_handler. After this returns the event handler will no longer receive events.


class torch.monitor.TensorboardEventHandler(writer)

TensorboardEventHandler is an event handler that will write known events to the provided SummaryWriter.

This currently only supports torch.monitor.Stat events which are logged as scalars.


Example :

>>> from torch.utils.tensorboard import SummaryWriter>>> from torch.monitor import TensorboardEventHandler, register_event_handler>>> writer = SummaryWriter("log_dir")>>> register_event_handler(TensorboardEventHandler(writer))

__init__(writer)

构建 TensorboardEventHandler



torch.signal 模块

torch.signal 模块的设计灵感来源于 SciPy 的 signal 模块。


torch.signal.windows 窗口函数

bartlett计算巴特利特(Bartlett)窗口
blackman计算布莱克曼(Blackman)窗口
cosine计算余弦窗口,实现方式与SciPy保持一致
exponential计算指数窗口
gaussian计算高斯窗口
general_cosine计算广义余弦窗口
general_hamming计算广义汉明(Hamming)窗口
hamming计算汉明(Hamming)窗口
hann计算汉宁(Hann)窗口
kaiser计算凯撒(Kaiser)窗口
nuttall根据Nuttall方法计算最小4项布莱克曼-哈里斯(Blackman-Harris)窗口


torch.special

torch.special 模块的设计灵感来源于 SciPy 的 special 模块。


函数


torch.special.airy_ai(input, *, out=None) → Tensor  

Airy 函数 Ai(input)Ai(input)Ai(input)

参数

  • input ( Tensor ) – 输入张量。

关键字参数

  • out ( Tensor , 可选) – 输出张量。

torch.special.bessel_j0(input, *, out=None) → Tensor  

第一类零阶贝塞尔函数。

参数

  • input ( Tensor ) – 输入张量。

关键字参数

  • out ( Tensor , 可选 ) – 输出张量。

torch.special.bessel_j1(input, *, out=None) → Tensor  

第一类111阶贝塞尔函数。

参数

  • input ( Tensor ) – 输入张量。

关键字参数

  • out ( Tensor , optional) – 输出张量。

torch.special.digamma(input, *, out=None) → Tensor  

计算输入张量的伽玛函数的对数导数。

ϝ(x)=ddxln⁡(Γ(x))=Γ′(x)Γ(x)\digamma(x) = \frac{d}{dx} \ln\left(\Gamma\left(x\right)\right) = \frac{\Gamma'(x)}{\Gamma(x)}ϝ(x)=dxdln(Γ(x))=Γ(x)Γ(x)

参数

  • input ( Tensor ) – 用于计算digamma函数的输入张量

关键字参数

  • out ( Tensor , optional) – 输出张量

注意:此函数与SciPy的scipy.special.digamma功能相似。

注意:从PyTorch 1.8开始,digamma函数在输入为0时会返回-Inf,而此前版本会返回NaN。


示例:

>>> a = torch.tensor([1, 0.5])
>>> torch.special.digamma(a)
tensor([-0.5772, -1.9635])

torch.special.entr(input, *, out=None) → Tensor  

计算输入张量 input 中各元素的熵(定义如下)。

$$

\begin{align}

\text{entr(x)} = \begin{cases}
-x * \ln(x) & x 0 \

0 & x = 0.0 \
-\infty & x < 0

\end{cases}

\end{align}

$$

参数说明

  • input ( Tensor ) – 输入张量。

关键字参数

  • out ( Tensor , optional) – 输出张量。

示例:

>>> a = torch.arange(-0.5, 1, 0.5)
>>> a tensor([-0.5000, 0.0000, 0.5000])
>>> torch.special.entr(a)
tensor([-inf, 0.0000, 0.3466])

torch.special.erf(input, *, out=None) → Tensor  

计算输入张量的误差函数。误差函数定义如下:

$$\mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_{0}^{x} e{-t2} dt

$$

参数

  • input ( Tensor ) – 输入张量。

关键字参数

  • out ( Tensor , optional) – 输出张量。

示例:

>>> torch.special.erf(torch.tensor([0, -1., 10.]))
tensor([0.0000, -0.8427, 1.0000])

torch.special.erfc(input, *, out=None) → Tensor  

计算输入张量的互补误差函数。

互补误差函数的定义如下:

$$

\mathrm{erfc}(x) = 1 - \frac{2}{\sqrt{\pi}} \int_{0}^{x} e{-t2} dt

$$

参数

  • input ( Tensor ) – 输入张量。

关键字参数

  • out ( Tensor , optional) – 输出张量。

示例:

>>> torch.special.erfc(torch.tensor([0, -1., 10.]))
tensor([1.0000, 1.8427, 0.0000])

torch.special.erfcx(input, *, out=None) → Tensor  

计算input中每个元素的缩放互补误差函数。

缩放互补误差函数的定义如下:

erfcx(x)=ex2erfc(x)\mathrm{erfcx}(x) = e^{x^2} erfc(x) erfcx(x)=ex2erfc(x)

erfcx(x)=ex2erfc(x)

参数

  • input ( Tensor ) - 输入张量。

关键字参数

  • out ( Tensor , 可选) - 输出张量。

示例

>>> torch.special.erfcx(torch.tensor([0, -1., 10.]))
tensor([1.0000, 5.0090, 0.0561])

torch.special.erfinv(input, *, out=None) → Tensor  

计算输入张量的反误差函数值。

反误差函数在区间 (−1,1) 内定义为:

erfinv(erf(x))=x\mathrm{erfinv(erf(x))}= x erfinv(erf(x))=x

参数说明

  • input ( Tensor ) - 输入张量

关键字参数

  • out ( Tensor , 可选) - 输出张量

使用示例


>>> torch.special.erfinv(torch.tensor([0, 0.5, -1.]))
tensor([0.0000, 0.4769, -inf])

torch.special.exp2(input, *, out=None) → Tensor  

计算 input 的以 2 为底的指数函数。

yi=2xiy_{i} = 2^{x_{i}}yi=2xi

参数

  • input ( Tensor ) – 输入张量。

关键字参数

  • out ( Tensor , 可选) – 输出张量。

示例:

>>> torch.special.exp2(torch.tensor([0, math.log2(2.), 3, 4]))
tensor([1., 2., 8., 16.])

torch.special.expit(input, *, out=None) → Tensor  

计算输入张量 input 各元素的 expit 值(也称为 logistic sigmoid 函数)。

outi=11+ei−input{out}_{i} = \frac{1}{1 + e^{-input}_{i}} outi=1+eiinput1

outi​=1+e−inputi​1​

参数

  • input ( Tensor ) - 输入张量。

关键字参数

  • out ( Tensor , 可选) - 输出张量。

示例:

>>> t = torch.randn(4)
>>> t
tensor([0.9213, 1.0887, -0.8858, -1.7683])
>>> torch.special.expit(t)
tensor([0.7153, 0.7481, 0.2920, 0.1458])

torch.special.expm1(input, *, out=None) → Tensor  

计算输入张量input各元素的指数值减1。

yi=exi−1y_{i} = e^{x_{i}} - 1yi=exi1

注意:对于较小的x值,该函数比直接计算exp(x) - 1能提供更高的精度。

参数说明

  • input ( Tensor ) - 输入张量。

关键字参数

  • out ( Tensor , 可选) - 输出张量。

使用示例


>>> torch.special.expm1(torch.tensor([0, math.log(2.)]))
tensor([0., 1.])

torch.special.gammainc(input, other, *, out=None) → Tensor  

计算正则化的下不完全伽马函数:

$$

\text{out}_{i} = \frac{1}{\Gamma(\text{input}_i)} \int_0^{\text{other}_i} t^{\text{input}_i-1} e^{-t} dt

$$

其中inputiinput_iinputiotheriother_iotheri均为弱正数且至少有一个严格为正数。若两者均为零或任一为负数,则outi=nan\text{out}_i=\text{nan}outi=nan。上述公式中的Γ\GammaΓ表示伽马函数:

$$

\Gamma(\text{input}_i) = \int_0^\infty t^{(\text{input}_i-1)} e^{-t} dt.

$$

相关函数请参阅torch.special.gammaincc()torch.special.gammaln()

支持广播至通用形状和浮点输入。

注意:目前不支持对input的反向传播。如需此功能,请在PyTorch的Github上提交issue。

参数

  • input ( Tensor ) – 第一个非负输入张量
  • other ( Tensor ) – 第二个非负输入张量

关键字参数

  • out ( Tensor , optional) – 输出张量

示例:

>>> a1 = torch.tensor([4.0])
>>> a2 = torch.tensor([3.0, 4.0, 5.0])
>>> a = torch.special.gammaincc(a1, a2)
tensor([0.3528, 0.5665, 0.7350])
tensor([0.3528, 0.5665, 0.7350])
>>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2)
tensor([1., 1., 1.])

torch.special.gammaincc(input, other, *, out=None) → Tensor  

计算正则化上不完全伽马函数:

$$\text{out}_{i} = \frac{1}{\Gamma(\text{input}i)} \int{\text{other}_i}^{\infty} t^{\text{input}_i-1} e^{-t} dt

$$

其中 inputiinput_iinputi​ 和 otheriother_iotheri​ 均为弱正数,且至少有一个严格为正数。若两者均为零或任一为负数,则 outi=nanout_i=nanouti=nan。上式中的 Γ(⋅)\Gamma(\cdot)Γ() 表示伽马函数,

$$\Gamma(\text{input}_i) = \int_0^\infty t^{(\text{input}_i-1)} e^{-t} dt.

$$

相关函数请参阅 torch.special.gammainc()torch.special.gammaln()

支持广播至相同形状及浮点输入。

注意:目前不支持对 input 的反向传播。如需此功能,请在 PyTorch 的 Github 上提交 issue。

参数

  • input ( Tensor ) – 第一个非负输入张量
  • other ( Tensor ) – 第二个非负输入张量

关键字参数

  • out ( Tensor , optional) – 输出张量。

示例:

>>> a1 = torch.tensor([4.0])
>>> a2 = torch.tensor([3.0, 4.0, 5.0])
>>> a = torch.special.gammaincc(a1, a2)
tensor([0.6472, 0.4335, 0.2650])
>>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2)
tensor([1., 1., 1.])

torch.special.gammaln(input, *, out=None) → Tensor  

计算输入张量绝对值的伽玛函数的自然对数。

outi=ln⁡Γ(∣inputi∣)\text{out}{i} = \ln \Gamma(|\text{input}{i}|)

outi​=lnΓ(∣inputi​∣)

参数

  • input ( Tensor ) - 输入张量。

关键字参数

  • out ( Tensor , optional) - 输出张量。

示例

>>> a = torch.arange(0.5, 2, 0.5)
>>> torch.special.gammaln(a)
tensor([0.5724, 0.0000, -0.1208])

torch.special.i0(input, *, out=None) → Tensor  

计算input中每个元素的第一类零阶修正贝塞尔函数。

$$\text{out}{i} = I_0(\text{input}{i}) = \sum_{k=0}^{\infty} \frac{(\text{input}_{i}2/4)k}{(k!)^2}

$$

参数

  • input ( Tensor ) - 输入张量

关键字参数

  • out ( Tensor , 可选) - 输出张量

示例

>>> torch.i0(torch.arange(5, dtype=torch.float32))
tensor([1.0000, 1.2661, 2.2796, 4.8808, 11.3019])

torch.special.i0e(input, *, out=None) → Tensor  

input 的每个元素计算指数缩放的第一类零阶修正贝塞尔函数(定义如下)。

$$\text{out}{i} = \exp(-|x|) * i0(x) = \exp(-|x|) * \sum{k=0}^{\infty} \frac{(\text{input}_{i}2/4)k}{(k!)^2}

$$

参数

  • input ( Tensor ) – 输入张量。

关键字参数

  • out ( Tensor , optional) – 输出张量。

示例:

>>> torch.special.i0e(torch.arange(5, dtype=torch.float32))
tensor([1.0000, 0.4658, 0.3085, 0.2430, 0.2070])

torch.special.i1(input, *, out=None) → Tensor  

计算input中每个元素的一阶第一类修正贝塞尔函数(定义如下)。

$$\text{out}_{i} = \exp(-|x|) * i1(x) =

\exp(-|x|) * \frac{(\text{input}{i})}{2} * \sum{k=0}^{\infty} \frac{(\text{input}_{i}2/4)k}{(k!) * (k+1)!}

$$

参数

  • input ( Tensor ) – 输入张量。

关键字参数

  • out ( Tensor , optional) – 输出张量。

示例

>>> torch.special.i1(torch.arange(5, dtype=torch.float32))
tensor([0.0000, 0.5652, 1.5906, 3.9534, 9.7595])

torch.special.i1e(input, *, out=None) → Tensor  

计算input中每个元素的指数缩放一阶第一类修正贝塞尔函数(定义如下):

outi=exp⁡(−∣x∣)∗i1(x)=exp⁡(−∣x∣)∗(inputi)2∗∑k=0∞(inputi2/4)k(k!)∗(k+1)!\text{out}_{i} = \exp(-|x|) * i1(x) =

\exp(-|x|) * \frac{(\text{input}{i})}{2} * \sum{k=0}^{\infty} \frac{(\text{input}_{i}2/4)k}{(k!) * (k+1)!}

outi​=exp(−∣x∣)∗i1(x)=exp(−∣x∣)∗2(inputi​)​∗k=0∑∞​(k!)∗(k+1)!(inputi2​/4)k​

参数

  • input ( Tensor ) – 输入张量。

关键字参数

  • out ( Tensor , optional) – 输出张量。

示例:

>>> torch.special.i1e(torch.arange(5, dtype=torch.float32))
tensor([0.0000, 0.2079, 0.2153, 0.1968, 0.1788])

torch.special.log1p(input, *, out=None) → Tensor  

torch.log1p() 的别名。


torch.special.log_ndtr(input, *, out=None) → Tensor  

计算标准高斯概率密度函数从负无穷到input的逐元素积分对数。

$$\text{log_softmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)

$$

参数

  • input ( Tensor ) – 输入张量。

关键字参数

  • out ( Tensor , 可选) – 输出张量。

示例

>>> torch.special.log_ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3]))
tensor([-6.6077 -3.7832 -1.841  -0.6931 -0.1728 -0.023  -0.0014])

torch.special.log_softmax(input, dim, *, dtype=None) → Tensor  

计算经过对数处理的softmax结果。

虽然在数学上等价于log(softmax(x)),但分开执行这两个操作会更慢且数值不稳定。该函数的计算方式如下:

log_softmax(xi)=log⁡(exp⁡(xi)∑jexp⁡(xj))\text{log_softmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)

log_softmax(xi​)=log(∑j​exp(xj​)exp(xi​)​)

参数说明

  • input ( Tensor ) – 输入张量
  • dim ( int ) – 指定计算log_softmax的维度
  • dtype ( torch.dtype , optional) – 返回张量的期望数据类型

如果指定该参数,在执行操作前会将输入张量转换为dtype类型。这有助于防止数据类型溢出。默认值:None。

使用示例:

>>> t = torch.ones(2, 2)
>>> torch.special.log_softmax(t, 0)
tensor([[-0.6931, -0.6931], [-0.6931, -0.6931]])

torch.special.logit(input, eps=None, *, out=None) → Tensor  

返回一个包含input元素logit值的新张量。

当eps不为None时,input会被截断到[eps, 1 - eps]区间。

当eps为None且input < 0或input > 1时,函数将返回NaN。

yi=ln⁡(zi1−zi)zi={xiif eps is Noneepsif xi<epsxiif eps≤xi≤1−eps1−epsif xi>1−eps\begin{align} y_{i} &= \ln(\frac{z_{i}}{1 - z_{i}}) \\ z_{i} &= \begin{cases} x_{i} & \text{if eps is None} \\ \text{eps} & \text{if } x_{i} < \text{eps} \\ x_{i} & \text{if } \text{eps} \leq x_{i} \leq 1 - \text{eps} \\ 1 - \text{eps} & \text{if } x_{i} > 1 - \text{eps} \end{cases} \end{align}yizi=ln(1zizi)=xiepsxi1epsif eps is Noneif xi<epsif epsxi1epsif xi>1eps

参数

  • input ( Tensor ) - 输入张量。
  • eps (float, 可选) - 用于输入截断的epsilon值。默认值:None

关键字参数

  • out ( Tensor , 可选) - 输出张量。

示例:

>>> a = torch.rand(5)
>>> a tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516])
>>> torch.special.logit(a, eps=1e-6)
tensor([-0.9466, 2.6352, 0.6131, -1.7169, 0.6261])

torch.special.logsumexp(input, dim, keepdim=False, *, out=None) 

torch.logsumexp() 的别名。


torch.special.multigammaln(input, p, *, out=None) → Tensor  

计算给定维度 p 的多元对数伽玛函数,按元素逐个计算,公式如下:

$$\log(\Gamma_{p}(a)) = C + \displaystyle \sum_{i=1}^{p} \log\left(\Gamma\left(a - \frac{i - 1}{2}\right)\right)

$$

其中 C=log⁡(π)⋅p(p−1)4C = \log(\pi) \cdot \frac{p (p - 1)}{4}C=log(π)4p(p1) 为伽玛函数。

所有元素必须大于 p−12\frac{p - 1}{2}2p1,否则行为未定义。

参数

  • input ( Tensor ) - 用于计算多元对数伽玛函数的张量
  • p ( int ) - 维度数量

关键字参数

  • out ( Tensor , optional) - 输出张量

示例

>>> a = torch.empty(2, 3).uniform_(1, 2)
>>> a tensor([[1.6835, 1.8474, 1.1929], [1.0475, 1.7162, 1.4180]])
>>> torch.special.multigammaln(a, 2)
tensor([[0.3928, 0.4007, 0.7586], [1.0311, 0.3901, 0.5049]])

torch.special.ndtr(input, *, out=None) → Tensor  

计算标准高斯概率密度函数从负无穷到输入值input的逐元素积分面积。

ndtr(x)=12π∫−∞xe−12t2dt\text{ndtr}(x) = \frac{1}{\sqrt{2 \pi}}\int_{-\infty}^{x} e^{-\frac{1}{2}t^2} dtndtr(x)=2π1xe21t2dt

参数

  • input ( Tensor ) – 输入张量。

关键字参数

  • out ( Tensor , optional) – 输出张量。

示例

>>> torch.special.ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3]))
tensor([0.0013, 0.0228, 0.1587, 0.5000, 0.8413, 0.9772, 0.9987])

torch.special.ndtri(input, *, out=None) → Tensor  

计算高斯概率密度函数下(从负无穷积分到x)面积等于input各元素值的对应参数x。

ndtri(p)=2erf−1(2p−1)\text{ndtri}(p) = \sqrt{2}\text{erf}^{-1}(2p - 1)ndtri(p)=2erf1(2p1)

注意:也称为正态分布的分位数函数。

参数说明:

  • input ( Tensor ) - 输入张量

关键字参数:

  • out ( Tensor , optional) - 输出张量

示例:

>>> torch.special.ndtri(torch.tensor([0, 0.25, 0.5, 0.75, 1]))
tensor([ -inf, -0.6745, 0.0000, 0.6745, inf])

torch.special.polygamma(n, input, *, out=None) → Tensor  

计算输入张量 input 的 digamma 函数的 nthn^{th}nth 阶导数。

其中 n≥0n \geq 0n0 称为多伽马函数的阶数。

$$\psi^{(n)}(x) = \frac{d{(n)}}{dx{(n)}} \psi(x)

$$

注意:此函数仅针对非负整数 n≥0 实现。

参数说明

  • n ( int ) – 多伽马函数的阶数
  • input ( Tensor ) – 输入张量

关键字参数

  • out ( Tensor , optional) – 输出张量

使用示例:

>>> a = torch.tensor([1, 0.5])
>>> torch.special.polygamma(1, a)
tensor([1.64493, 4.9348])
>>> torch.special.polygamma(2, a)
tensor([-2.4041, -16.8288])
>>> torch.special.polygamma(3, a)
tensor([6.4939, 97.4091])
>>> torch.special.polygamma(4, a)
tensor([-24.8863, -771.4742])

torch.special.psi(input, *, out=None) → Tensor  

torch.special.digamma() 的别名。


torch.special.round(input, *, out=None) → Tensor  

torch.round() 的别名。


torch.special.scaled_modified_bessel_k0(input, *, out=None) → Tensor  

二阶修正贝塞尔函数(阶数为0)。

参数

  • input ( Tensor ) – 输入张量。

关键字参数

  • out ( Tensor , 可选) – 输出张量。

torch.special.scaled_modified_bessel_k1(input, *, out=None) → Tensor  

第二类111阶缩放修正贝塞尔函数。

参数

  • input ( Tensor ) – 输入张量。

关键字参数

  • out ( Tensor , 可选 ) – 输出张量。

torch.special.sinc(input, *, out=None) → Tensor  

计算 input 的归一化 sinc 函数值。

$$\text{out}_{i} =

\begin{cases}

1, & \text{if}\ \text{input}_{i}=0 \

\sin(\pi \text{input}{i}) / (\pi \text{input}{i}), & \text{otherwise}

\end{cases}

$$

参数

  • input ( Tensor ) - 输入张量。

关键字参数

  • out ( Tensor , optional) - 输出张量。

示例

>>> t = torch.randn(4)
>>> t
tensor([0.2252, -0.2948, 1.0267, -1.1566])
>>> torch.special.sinc(t)
tensor([0.9186, 0.8631, -0.0259, -0.1300])

torch.special.softmax(input, dim, *, dtype=None) → Tensor  

计算softmax函数。

Softmax的定义如下:

Softmax(xi)=exp⁡(xi)∑jexp⁡(xj)\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}Softmax(xi)=jexp(xj)exp(xi)

该函数会沿着指定维度dim对所有切片进行计算,并将结果重新缩放,使元素值落在[0, 1]区间且总和为1。

参数

  • input ( Tensor ) – 输入张量
  • dim ( int ) – 指定计算softmax的维度
  • dtype ( torch.dtype , 可选) – 返回张量的期望数据类型

若指定该参数,在执行操作前会将输入张量转换为dtype类型。这有助于防止数据类型溢出。默认值:None。

示例::


>>> t = torch.ones(2, 2)
>>> torch.special.softmax(t, 0)
tensor([[0.5000, 0.5000], [0.5000, 0.5000]])

torch.special.spherical_bessel_j0(input, *, out=None) → Tensor  

一阶球面贝塞尔函数(阶数为000)。

参数

  • input ( Tensor ) – 输入张量。

关键字参数

  • out ( Tensor , optional) – 输出张量。

torch.special.xlog1py(input, other, *, out=None) → Tensor  

计算 input * log1p(other),具体分为以下几种情况:

outi={NaNif otheri=NaN0if inputi=0.0and otheri!=NaNinputi∗log1p(otheri)otherwise\text{out}_{i} = \begin{cases} \text{NaN} & \text{if } \text{other}_{i} = \text{NaN} \\ 0 & \text{if } \text{input}_{i} = 0.0 \text{ and } \text{other}_{i} != \text{NaN} \\ \text{input}_{i} * \text{log1p}(\text{other}_{i}) & \text{otherwise} \end{cases} outi=NaN0inputilog1p(otheri)if otheri=NaNif inputi=0.0 and otheri!=NaNotherwise

与 SciPy 的 scipy.special.xlog1py 功能类似。

参数

  • input (Number* 或 Tensor) – 乘数
  • other (Number* 或 Tensor) – 参数

注意inputother 中至少有一个必须是张量。

关键字参数

  • out (Tensor, 可选) – 输出张量。

示例

>>> x = torch.zeros(5,)
>>> y = torch.tensor([-1, 0, 1, float('inf'), float('nan')])
>>> torch.special.xlog1py(x, y)
tensor([0., 0., 0., 0., nan])
>>> x = torch.tensor([1, 2, 3])
>>> y = torch.tensor([3, 2, 1])
>>> torch.special.xlog1py(x, y)
tensor([1.3863, 2.1972, 2.0794])
>>> torch.special.xlog1py(x, 4)
tensor([1.6094, 3.2189, 4.8283])
>>> torch.special.xlog1py(2, y)
tensor([2.7726, 2.1972, 1.3863])

torch.special.xlogy(input, other, *, out=None) → Tensor  

计算 input * log(other),具体有以下几种情况:

outi={NaN若 otheri=NaN0若 inputi=0.0inputi∗log⁡(otheri)其他情况\text{out}_{i} = \begin{cases}

\text{NaN} & \text{if } \text{other}_{i} = \text{NaN} \\

0 & \text{if } \text{input}_{i} = 0.0 \\

\text{input}{i} * \log{(\text{other}{i})} & \text{otherwise}

\end{cases}

outi​=⎩⎨⎧​NaN0inputi​∗log(otheri​)​若 otheri​=NaN若 inputi​=0.0其他情况​类似于 SciPy 的 scipy.special.xlogy 函数。

参数

  • input (Number* 或 Tensor) – 乘数
  • other (Number* 或 Tensor) – 参数

注意:inputother 中至少有一个必须是张量。

关键字参数

  • out (Tensor, 可选) – 输出张量。

示例:

>>> x = torch.zeros(5,)
>>> y = torch.tensor([-1, 0, 1, float('inf'), float('nan')])
>>> torch.special.xlogy(x, y)
tensor([0., 0., 0., 0., nan])
>>> x = torch.tensor([1, 2, 3])
>>> y = torch.tensor([3, 2, 1])
>>> torch.special.xlogy(x, y)
tensor([1.0986, 1.3863, 0.0000])
>>> torch.special.xlogy(x, 4)
tensor([1.3863, 2.7726, 4.1589])
>>> torch.special.xlogy(2, y)
tensor([2.1972, 1.3863, 0.0000])

torch.special.zeta(input, other, *, out=None) → Tensor  

逐元素计算 Hurwitz zeta 函数。

$$\zeta(x, q) = \sum_{k=0}^{\infty} \frac{1}{(k + q)^x}

$$

参数

  • input ( Tensor ) – 对应 x 的输入张量。
  • other ( Tensor ) – 对应 q 的输入张量。

注意:当 q = 1 时即为黎曼 zeta 函数。

关键字参数

  • out ( Tensor , optional) – 输出张量。

示例:

>>> x = torch.tensor([2., 4.])
>>> torch.special.zeta(x, 1)
tensor([1.6449, 1.0823])
>>> torch.special.zeta(x, torch.tensor([1., 2.]))
tensor([1.6449, 0.0823])
>>> torch.special.zeta(2, torch.tensor([1., 2.]))
tensor([1.6449, 0.6449])


torch.overrides

该模块提供了多种辅助函数,用于支持__torch_function__协议。有关__torch_function__协议的更多详细信息,请参阅扩展torch Python API。


函数


torch.overrides.get_ignored_functions()

返回无法被__torch_function__覆盖的公共函数。

返回值:一个包含torch API中公开但无法通过__torch_function__覆盖的函数的元组。这主要是因为这些函数的参数都不是张量或类张量对象。

返回类型:set[Callable]


示例

>>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions()
True
>>> torch.add in torch.overrides.get_ignored_functions()
False

torch.overrides.get_overridable_functions()

可通过 __torch_function__ 重写的函数列表

返回值:一个字典,将包含可重写函数的命名空间映射到该命名空间中可被重写的函数。

返回类型:Dict[Any, List[Callable]]


torch.overrides.resolve_name(f)

获取传递给__torch_function__的函数的人类可读字符串名称


参数

  • f (Callable) – 需要解析名称的函数。

返回值:该函数的名称;如果对其进行求值,应能返回输入函数。

返回类型:str


torch.overrides.get_testing_overrides()

返回一个包含所有可覆盖函数的虚拟重载的字典

返回值:一个字典,将 PyTorch API 中的可覆盖函数映射到具有相同签名的 lambda 函数,这些 lambda 函数无条件返回 -1。这些 lambda 函数对于测试定义了 __torch_function__ 类型的 API 覆盖率非常有用。

返回类型:Dict[Callable, Callable]


示例

>>> import inspect
>>> my_add = torch.overrides.get_testing_overrides()[torch.add]
>>> inspect.signature(my_add)
<Signature (input, other, out=None)>

torch.overrides.handle_torch_function(public_api, relevant_args, *args, **kwargs)

实现一个检查__torch_function__重载的函数。

在C++实现中,与此函数等效的是torch::autograd::handle_torch_function。

参数

  • public_api (function) - 最初以public_api(args, *kwargs)形式调用的公开torch API函数,现在正在检查其参数。
  • relevant_args (iterable) - 需要检查__torch_function__方法的参数迭代器。
  • args (tuple) - 最初传入public_api的任意位置参数。
  • kwargs (tuple) - 最初传入public_api的任意关键字参数。

返回

根据情况返回调用implementation__torch_function__方法的结果。

返回类型

object

:raises TypeError: 如果找不到实现。


示例

>>> def func(a):
...     if has_torch_function_unary(a):
...         return handle_torch_function(func, (a,), a)
...     return a + 0

torch.overrides.has_torch_function() 

检查可迭代对象中的元素是否实现了__torch_function__,或者是否启用了__torch_function__模式。注意:精确的TensorParameter被视为不可调度类型。此方法用于保护对handle_torch_function()的调用,不要用它来检测对象是否类似Tensor——请改用is_tensor_like()

:param relevant_args: 需要检查__torch_function__方法的可迭代对象或参数
:type relevant_args: iterable

返回值

如果relevant_args中任何元素实现了__torch_function__则返回True,否则返回False。

返回类型 : bool

另请参阅

torch.is_tensor_like 检测对象是否为Tensor-like(包括精确的Tensor


torch.overrides.is_tensor_like(inp)

如果传入的输入是类张量(Tensor-like)对象,则返回True

当前实现中,只要输入对象的类型具有__torch_function__属性即视为类张量。

示例:
张量的子类通常属于类张量对象。


>>> class SubTensor(torch.Tensor): ...
>>> is_tensor_like(SubTensor([0]))
True

内置类型或用户自定义类型通常不具备 Tensor 的特性。


>>> is_tensor_like(6)
False
>>> is_tensor_like(None)
False
>>> class NotATensor: ...
>>> is_tensor_like(NotATensor())
False

但是,可以通过实现 __torch_function__ 使它们具备类似张量的特性。


>>> class TensorLike:
...     @classmethod
...     def __torch_function__(cls, func, types, args, kwargs):
...         return -1
>>> is_tensor_like(TensorLike())
True

torch.overrides.is_tensor_method_or_property(func)

如果传入的函数是 torch.Tensor 方法或属性的处理程序(如传入 __torch_function__ 时),则返回 True。

注意:对于属性,必须传入其 __get__ 方法。

这在以下情况下尤其需要:

1、方法/属性有时不包含 module 槽位
2、它们要求第一个传入参数必须是 torch.Tensor 的实例


示例

>>> is_tensor_method_or_property(torch.Tensor.add)
True
>>> is_tensor_method_or_property(torch.add)
False

返回类型:bool


torch.overrides.wrap_torch_function(dispatcher)

__torch_function__相关功能包装给定的函数。

参数

  • dispatcher (Callable) – 一个可调用对象,返回传入函数中的类Tensor对象的可迭代集合。

注意:此装饰器可能会降低代码性能。通常,将代码表达为一系列自身支持__torch_function__的函数就足够了。如果您遇到罕见情况(例如在封装底层库时也需要使其支持类Tensor对象),则可以使用此函数。


示例

>>> def dispatcher(a):  # Must have the same signature as func
...     return (a,)
>>> @torch.overrides.wrap_torch_function(dispatcher)
>>> def func(a):  # This will make func dispatchable by __torch_function__
...     return a + 0


torch.package

torch.package 提供了创建包含工件和任意 PyTorch 代码的包的支持。这些包可以被保存、共享,用于在之后的时间或不同的机器上加载和执行模型,甚至可以使用 torch::deploy 部署到生产环境。

本文档包含教程、操作指南、说明和 API 参考,将帮助您了解更多关于 torch.package 的信息以及如何使用它。


警告:此模块依赖于不安全的 pickle 模块。仅解包您信任的数据。

恶意构造的 pickle 数据可能会在解包过程中执行任意代码。切勿解包可能来自不受信任来源或可能被篡改的数据。

更多信息,请查阅 pickle 模块的文档。


教程


打包你的第一个模型

我们提供了一个教程,引导你完成打包和解包一个简单模型的流程,该教程可在 Colab 上查看。完成这个练习后,你将熟悉创建和使用 Torch 包的基本 API。


如何实现…


查看包内包含哪些内容?


将包视为ZIP归档文件处理

torch.package的容器格式采用ZIP标准,因此任何适用于标准ZIP文件的工具都能用于查看其内容。以下是操作ZIP文件的常用方法:

  • 执行unzip my_package.pt命令可将torch.package归档解压到磁盘,便于自由检查其内容。

$ unzip my_package.pt && tree my_package
my_package
├── .data
│   ├── 94304870911616.storage
│   ├── 94304900784016.storage
│   ├── extern_modules
│   └── version
├── models
│   └── model_1.pkl
└── torchvision└── models├── resnet.py└── utils.py
~ cd my_package && cat torchvision/models/resnet.py
...

  • Python的zipfile模块提供了读写ZIP归档文件内容的标准方法。

from zipfile import ZipFile with ZipFile("my_package.pt") as myzip:file_bytes = myzip.read("torchvision/models/resnet.py")# edit file_bytes in some waymyzip.writestr("torchvision/models/resnet.py", new_file_bytes)

  • Vim 原生支持读取 ZIP 压缩包。你甚至可以直接编辑文件并通过 :write 命令将修改写回压缩包!

# add this to your .vimrc to treat `*.pt` files as zip files
au BufReadCmd *.pt call zip#Browse(expand("<amatch>"))~ vi my_package.pt

使用 file_structure() API

PackageImporter 提供了一个 file_structure() 方法,该方法会返回一个可打印且可查询的 Directory 对象。Directory 对象是一个简单的目录结构,可用于查看 torch.package 的当前内容。

Directory 对象本身可以直接打印,并会输出文件树的表示形式。如需过滤返回的内容,可使用 glob 风格的 includeexclude 过滤参数。


with PackageExporter('my_package.pt') as pe:pe.save_pickle('models', 'model_1.pkl', mod)importer = PackageImporter('my_package.pt')
# can limit printed items with include/exclude args
print(importer.file_structure(include=["**/utils.py", "**/*.pkl"], exclude="**/*.storage"))
print(importer.file_structure()) # will print out all files

Output:


# filtered with glob pattern:
#    include=["**/utils.py", "**/*.pkl"], exclude="**/*.storage"
─── my_package.pt├── models│   └── model_1.pkl└── torchvision└── models└── utils.py# all files
─── my_package.pt├── .data│   ├── 94304870911616.storage│   ├── 94304900784016.storage│   ├── extern_modules│   └── version├── models│   └── model_1.pkl└── torchvision└── models├── resnet.py└── utils.py

你也可以使用 has_file() 方法查询 Directory 对象。


importer_file_structure = importer.file_structure()
found: bool = importer_file_structure.has_file("package_a/subpackage.py")

查看某个模块为何被列为依赖项?

假设有一个模块 foo,你想知道为什么 PackageExporter 会将其作为依赖项引入。

PackageExporter.get_rdeps() 方法会返回所有直接依赖 foo 的模块。

如果想查看特定模块 src 如何依赖 foo,可以使用 PackageExporter.all_paths() 方法,该方法会返回一个 DOT 格式的图表,展示 srcfoo 之间的所有依赖路径。

如果只想查看 PackageExporter 的完整依赖关系图,可以使用 PackageExporter.dependency_graph_string() 方法。


如何在打包时包含任意资源并后续访问?

PackageExporter 提供了三个方法:save_picklesave_textsave_binary,允许你将 Python 对象、文本和二进制数据保存到包中。


with torch.PackageExporter("package.pt") as exporter:# Pickles the object and saves to `my_resources/tensor.pkl` in the archive.exporter.save_pickle("my_resources", "tensor.pkl", torch.randn(4))exporter.save_text("config_stuff", "words.txt", "a sample string")exporter.save_binary("raw_data", "binary", my_bytes)

PackageImporter 提供了三个互补方法:load_pickleload_textload_binary,用于从包中加载 Python 对象、文本数据和二进制数据。


importer = torch.PackageImporter("package.pt")
my_tensor = importer.load_pickle("my_resources", "tensor.pkl")
text = importer.load_text("config_stuff", "words.txt")
binary = importer.load_binary("raw_data", "binary")

自定义类的打包方式

torch.package 允许自定义类的打包方式。这一行为通过以下两种方式实现:在类上定义方法 __reduce_package__,并定义对应的解包函数。这与为 Python 常规的 pickle 过程定义 __reduce__ 类似。

操作步骤:

1、在目标类上定义方法 __reduce_package__(self, exporter: PackageExporter)。该方法负责将类实例保存到包中,并应返回一个元组,包含对应的解包函数及调用该函数所需的参数。当 PackageExporter 遇到目标类的实例时,会调用此方法。
2、为类定义一个解包函数。该解包函数负责重建并返回类的实例。其函数签名的第一个参数应为 PackageImporter 实例,其余参数由用户自定义。


# foo.py [Example of customizing how class Foo is packaged]
from torch.package import PackageExporter, PackageImporter
import timeclass Foo:def __init__(self, my_string: str):super().__init__()self.my_string = my_stringself.time_imported = 0self.time_exported = 0def __reduce_package__(self, exporter: PackageExporter):"""Called by ``torch.package.PackageExporter``'s Pickler's ``persistent_id`` whensaving an instance of this object. This method should do the work to save this         object inside of the ``torch.package`` archive.Returns function w/ arguments to load the object from a         ``torch.package.PackageImporter``'s Pickler's ``persistent_load`` function."""# use this pattern to ensure no naming conflicts with normal dependencies,  # anything saved under this module name shouldn't conflict with other# items in the packagegenerated_module_name = f"foo-generated._{exporter.get_unique_id()}"exporter.save_text(generated_module_name,      "foo.txt",      self.my_string + ", with exporter modification!",  )time_exported = time.clock_gettime(1)# returns de-packaging function w/ arguments to invoke with         return (unpackage_foo, (generated_module_name, time_exported,))def unpackage_foo(importer: PackageImporter, generated_module_name: str, time_exported: float
) -Foo:"""Called by ``torch.package.PackageImporter``'s Pickler's ``persistent_load`` functionwhen depickling a Foo object.Performs work of loading and returning a Foo instance from a ``torch.package`` archive."""time_imported = time.clock_gettime(1)foo = Foo(importer.load_text(generated_module_name, "foo.txt"))foo.time_imported = time_importedfoo.time_exported = time_exportedreturn foo

# example of saving instances of class Fooimport torch
from torch.package import PackageImporter, PackageExporter
import foofoo_1 = foo.Foo("foo_1 initial string")
foo_2 = foo.Foo("foo_2 initial string") with PackageExporter('foo_package.pt') as pe:# save as normal, no extra work necessarype.save_pickle('foo_collection', 'foo1.pkl', foo_1)pe.save_pickle('foo_collection', 'foo2.pkl', foo_2)pi = PackageImporter('foo_package.pt')
print(pi.file_structure())
imported_foo = pi.load_pickle('foo_collection', 'foo1.pkl')
print(f"foo_1 string: '{imported_foo.my_string}'")
print(f"foo_1 export time: {imported_foo.time_exported}")
print(f"foo_1 import time: {imported_foo.time_imported}")

# output of running above script
─── foo_package├── foo-generated│   ├── _0│   │   └── foo.txt│   └── _1│       └── foo.txt├── foo_collection│   ├── foo1.pkl│   └── foo2.pkl└── foo.pyfoo_1 string: 'foo_1 initial string, with reduction modification!'
foo_1 export time: 9857706.650140837
foo_1 import time: 9857706.652698385

如何在源码中检测当前是否运行在包环境中?

PackageImporter 会在初始化每个模块时为其添加 __torch_package__ 属性。你的代码可以通过检查该属性是否存在,来判断当前是否处于打包后的运行环境中。


# In foo/bar.py:if "__torch_package__" in dir():  # true if the code is being loaded from a packagedef is_in_package():return TrueUserException = Exception
else:def is_in_package():return FalseUserException = UnpackageableException

现在,代码的行为会根据它是通过Python环境正常导入还是从torch.package导入而有所不同。


from foo.bar import is_in_packageprint(is_in_package())  # Falseloaded_module = PackageImporter(my_package).import_module("foo.bar")
loaded_module.is_in_package()  # True

警告:通常情况下,让代码在打包前后表现不一致是一种不良实践。这会导致难以调试的问题,且问题表现会因代码导入方式的不同而敏感变化。如果你的包预计会被频繁使用,建议重构代码,确保无论以何种方式加载,其行为都保持一致。


如何将代码补丁打入包中?

PackageExporter 提供了 save_source_string() 方法,允许你将任意 Python 源代码保存到指定的模块中。


with PackageExporter(f) as exporter:# Save the my_module.foo available in your current Python environment.exporter.save_module("my_module.foo")# This saves the provided string to my_module/foo.py in the package archive.# It will override the my_module.foo that was previously saved.exporter.save_source_string("my_module.foo", textwrap.dedent("""\def my_function():print('hello world')"""))# If you want to treat my_module.bar as a package# (e.g. save to `my_module/bar/__init__.py` instead of `my_module/bar.py)# pass is_package=True, exporter.save_source_string("my_module.bar",        "def foo(): print('hello')\n",        is_package=True)importer = PackageImporter(f)
importer.import_module("my_module.foo").my_function()  # prints 'hello world'

如何从打包代码中访问包内容?

PackageImporter 实现了 importlib.resources API,用于从包内部访问资源。


with PackageExporter(f) as exporter:# saves text to my_resource/a.txt in the archiveexporter.save_text("my_resource", "a.txt", "hello world!")# saves the tensor to my_pickle/obj.pklexporter.save_pickle("my_pickle", "obj.pkl", torch.ones(2, 2))# see below for module contentsexporter.save_module("foo")exporter.save_module("bar")

importlib.resources API 允许从打包代码中访问资源。


# foo.py:
import importlib.resources
import my_resource# returns "hello world!"
def get_my_resource():return importlib.resources.read_text(my_resource, "a.txt")

推荐使用 importlib.resources 来访问打包代码中的包内容,因为它符合 Python 标准。不过,也可以直接从打包代码中访问父级 PackageImporter 实例本身。


# bar.py:
import torch_package_importer # this is the PackageImporter that imported this module.# Prints "hello world!", equivalent to importlib.resources.read_text
def get_my_resource():return torch_package_importer.load_text("my_resource", "a.txt")# You also do things that the importlib.resources API does not support, like loading
# a pickled object from the package.
def get_my_pickle():return torch_package_importer.load_pickle("my_pickle", "obj.pkl")

区分打包代码与非打包代码

要判断一个对象的代码是否来自 torch.package,可使用 torch.package.is_from_package() 函数。
注意:若对象来自某个包但其定义源自标记为 extern 的模块或来自 stdlib,此检查将返回 False


importer = PackageImporter(f)
mod = importer.import_module('foo')
obj = importer.load_pickle('model', 'model.pkl')
txt = importer.load_text('text', 'my_test.txt')assert is_from_package(mod)
assert is_from_package(obj)
assert not is_from_package(txt) # str is from stdlib, so this will return False

如何重新导出已导入的对象?

要通过新的 PackageExporter 重新导出之前由 PackageImporter 导入的对象,必须让导出器知晓原始导入器的存在,这样才能正确找到对象依赖项的源代码。


importer = PackageImporter(f)
obj = importer.load_pickle("model", "model.pkl")# re-export obj in a new package with PackageExporter(f2, importer=(importer, sys_importer)) as exporter:exporter.save_pickle("model", "model.pkl", obj)

如何打包 TorchScript 模块?

要打包 TorchScript 模型,可以使用与其他对象相同的 save_pickleload_pickle API。TorchScript 对象作为属性或子模块时也支持直接保存,无需额外操作。


# save TorchScript just like any other object with PackageExporter(file_name) as e:e.save_pickle("res", "script_model.pkl", scripted_model)e.save_pickle("res", "mixed_model.pkl", python_model_with_scripted_submodule)
# load as normal
importer = PackageImporter(file_name)
loaded_script = importer.load_pickle("res", "script_model.pkl")
loaded_mixed = importer.load_pickle("res", "mixed_model.pkl"

说明


torch.package 格式概述

torch.package 文件是一个 ZIP 归档文件,通常使用 .pt 扩展名。ZIP 归档内包含两类文件:

  • 框架文件:存放在 .data/ 目录下
  • 用户文件:其余所有文件

例如,以下是一个完整打包的 torchvision ResNet 模型的结构示例:

resnet
├── .data  # All framework-specific data is stored here.
│   │      # It's named to avoid conflicts with user-serialized code.
│   ├── 94286146172688.storage  # tensor data
│   ├── 94286146172784.storage
│   ├── extern_modules  # text file with names of extern modules (e.g. 'torch')
│   ├── version         # version metadata
│   ├── ...
├── model  # the pickled model
│   └── model.pkl
└── torchvision  # all code dependencies are captured as source files└── models├── resnet.py└── utils.py

框架文件

.data/ 目录由 torch.package 所有,其内容被视为私有实现细节。torch.package 格式不保证 .data/ 目录内容的具体结构,但所有改动都将保持向后兼容性(即新版本的 PyTorch 始终能够加载旧版 torch.package)。

目前,.data/ 目录包含以下内容:

  • version:序列化格式的版本号,用于让 torch.package 的导入基础设施知道如何加载该包。
  • extern_modules:被视为 extern 的模块列表。extern 模块将使用加载环境的系统导入器进行导入。
  • *.storage:序列化的张量数据。

.data
├── 94286146172688.storage
├── 94286146172784.storage
├── extern_modules
├── version
├── ...

用户文件

归档中的所有其他文件都是由用户放置的。其目录结构与 Python 的常规包完全一致。若想深入了解 Python 的打包机制,请参阅这篇文章(内容稍有过时,具体实现细节请以 Python 参考文档为准)。


<package root>
├── model  # the pickled model
│   └── model.pkl
├── another_package
│   ├── __init__.py
│   ├── foo.txt         # a resource file , see importlib.resources
│   └── ...
└── torchvision└── models├── resnet.py   # torchvision.models.resnet└── utils.py    # torchvision.models.utils

torch.package 如何查找代码依赖项


分析对象的依赖关系

当你调用 save_pickle(obj, ...) 时,PackageExporter 会正常地对对象进行 pickle 序列化。随后,它会使用标准库模块 pickletools 来解析 pickle 字节码。

在 pickle 序列化过程中,对象会与一个 GLOBAL 操作码一起保存,该操作码描述了如何找到对象类型的实现位置,例如:

GLOBAL 'torchvision.models.resnet Resnet` 

依赖解析器会收集所有GLOBAL操作,并将它们标记为待序列化对象的依赖项。

有关序列化及pickle格式的更多信息,请参阅Python官方文档。


分析模块依赖关系

当识别出某个Python模块作为依赖项时,torch.package会遍历该模块的Python抽象语法树(AST)表示,并查找其中的导入语句。它完全支持标准导入形式:from x import yimport zfrom w import v as u等。当遇到这些导入语句时,torch.package会将导入的模块注册为依赖项,随后以同样的AST遍历方式解析这些依赖模块。

注意:AST解析对__import__(...)语法的支持有限,且不支持importlib.import_module调用。通常不应期望torch.package能检测到动态导入。


依赖管理

torch.package 会自动发现你的代码和对象所依赖的 Python 模块。这一过程称为依赖解析。
对于依赖解析器找到的每个模块,你必须指定一个操作来处理它。

允许的操作包括:

  • intern:将该模块打包到包中。
  • extern:声明该模块为包的外部依赖项。
  • mock:将该模块替换为存根。
  • deny:依赖此模块会在包导出时引发错误。

此外,还有一个重要的操作虽然技术上不属于 torch.package 的一部分:

  • 重构:移除或更改代码中的依赖项。

注意,操作仅针对整个 Python 模块定义,无法仅打包模块中的某个函数或类而忽略其余部分。
这是有意为之的设计。Python 并未提供模块内对象之间的清晰边界,模块是依赖组织的唯一明确定义单元,因此 torch.package 也采用这一标准。

操作通过模式应用于模块。模式可以是模块名称(如 "foo.bar")或通配符(如 "foo.**")。你可以使用 PackageExporter 上的方法将模式与操作关联,例如:

my_exporter.intern("torchvision.**")
my_exporter.extern("numpy")

如果模块匹配某个模式,就会对其应用相应的操作。对于给定的模块,系统会按照模式定义的顺序依次检查,并执行第一个匹配的操作。


intern

如果一个模块被标记为intern,它将被放入包中。

此操作适用于你的模型代码,或任何你想打包的相关代码。例如,如果你要打包torchvision中的ResNet模型,就需要将模块torchvision.models.resnet标记为intern

当导入包时,如果打包的代码尝试导入一个被标记为intern的模块,PackageImporter会在你的包内查找该模块。如果找不到该模块,则会抛出错误。这确保了每个PackageImporter与加载环境隔离——即使my_interned_module同时在你的包和加载环境中可用,PackageImporter也只会使用包内的版本。

注意:只有Python源码模块可以被标记为intern。其他类型的模块(如C扩展模块和字节码模块)如果尝试标记为intern会抛出错误。这类模块需要被标记为mockextern


extern

如果一个模块被声明为extern,它将不会被包含在包中。相反,该模块会被添加到当前包的外部依赖列表中。你可以在package_exporter.extern_modules中找到这个列表。

当导入包时,如果打包后的代码尝试导入一个被声明为extern的模块,PackageImporter会使用默认的Python导入器来查找该模块,就像执行了importlib.import_module("my_externed_module")一样。如果找不到该模块,则会抛出错误。

通过这种方式,你可以在包中依赖第三方库(如numpyscipy),而无需将它们一并打包。

警告:如果任何外部库发生了不向后兼容的更改,你的包可能无法加载。如果需要长期保证包的复现性,请尽量减少使用extern


mock

当一个模块被mock时,它不会被打包。取而代之的是一个存根模块会被打包到该位置。这个存根模块允许你从中获取对象(因此from my_mocked_module import foo不会报错),但任何使用该对象的操作都会引发NotImplementedError

mock应该用于那些你"确定"在加载的包中不需要,但仍希望在非打包内容中可用的代码。例如初始化/配置代码,或仅用于调试/训练的代码。

警告:一般来说,mock应该作为最后手段使用。它会导致打包代码和非打包代码之间的行为差异,可能引发后续混淆。建议优先通过重构代码来移除不需要的依赖项。


代码重构

管理依赖关系的最佳方式就是彻底消除依赖!通过重构代码,我们往往可以移除不必要的依赖。以下是编写低依赖代码的指导原则(这些原则本身也是优秀的编码实践):

只引入真正用到的内容。不要在代码中保留未使用的导入项。依赖解析器无法智能识别这些未使用的导入,仍会尝试处理它们。

精确限定导入范围。例如,与其写import foo然后在代码中使用foo.bar.baz,不如直接写from foo.bar import baz。这种方式能更精准地声明实际依赖(foo.bar),让依赖解析器明确你不需要引入整个foo包。

将包含无关功能的大文件拆分为小模块。如果你的utils模块混杂了大量无关功能,任何依赖该模块的代码都不得不引入许多无关依赖——即使你只需要其中一小部分功能。更好的做法是创建功能单一的小模块,这些模块可以彼此独立地打包。


模式

模式允许您通过简洁的语法来指定模块组。其语法和行为遵循 Bazel/Buck 的 glob() 规范。

我们将尝试与模式匹配的模块称为候选对象。候选对象由通过分隔符字符串分隔的多个段组成,例如 foo.bar.baz

一个模式包含一个或多个段。段可以是以下类型:

  • 字面字符串(如 foo),表示精确匹配
  • 包含通配符的字符串(如 torchfoo*baz*)。通配符可以匹配任意字符串,包括空字符串
  • 双通配符(**)。这将匹配零个或多个完整段

示例:

  • torch.**:匹配 torch 及其所有子模块,例如 torch.nntorch.nn.functional
  • torch.*:匹配 torch.nntorch.functional,但不匹配 torch.nn.functionaltorch
  • torch*.**:匹配 torchtorchvision 及其所有子模块

在指定操作时,您可以传入多个模式,例如


exporter.intern(["torchvision.models.**", "torchvision.utils.**"])

模块将匹配此操作,只要符合其中任一模式。

您还可以指定要排除的模式,例如


exporter.mock("**", exclude=["torchvision.**"])

如果模块匹配任何排除模式,则该模块不会与此操作匹配。在本示例中,我们模拟了除 torchvision 及其子模块之外的所有模块。

当一个模块可能匹配多个操作时,将优先采用第一个定义的操作。


torch.package 的注意事项


避免在模块中使用全局状态

Python 可以非常方便地在模块作用域内绑定对象和运行代码。这通常没有问题——毕竟函数和类也是通过这种方式绑定到名称的。但当你定义一个模块作用域内的可变对象时,就会引入可变的全局状态,情况就变得复杂了。

可变全局状态非常有用——它可以减少样板代码、允许开放注册到表中等等。但除非非常谨慎地使用,否则在与 torch.package 一起使用时可能会导致问题。

每个 PackageImporter 都会为其内容创建一个独立的环境。这很好,因为它意味着我们可以加载多个包并确保它们彼此隔离。但当模块的编写方式假设存在共享的可变全局状态时,这种行为可能会导致难以调试的错误。


类型在包与加载环境之间不共享

通过 PackageImporter 导入的任何类,都将是该导入器特有的类版本。例如:

from foo import MyClassmy_class_instance = MyClass()with PackageExporter(f) as exporter:exporter.save_module("foo")importer = PackageImporter(f)
imported_MyClass = importer.import_module("foo").MyClassassert isinstance(my_class_instance, MyClass)  # works
assert isinstance(my_class_instance, imported_MyClass)  # ERROR!

在这个示例中,MyClassimported_MyClass 不是同一类型。虽然在这个特定例子中,MyClassimported_MyClass 的实现完全相同,你可能会认为它们可以视为同一个类。但设想一下,如果 imported_MyClass 来自一个旧版本的包,其中 MyClass 的实现完全不同——这种情况下,将它们视为同一个类是不安全的。

实际上,每个导入器都有一个唯一标识类的前缀:

print(MyClass.__name__)  # prints "foo.MyClass"
print(imported_MyClass.__name__)  # prints <torch_package_0>.foo.MyClass

这意味着当其中一个参数来自某个包而另一个不是时,你不应期望isinstance检查能正常工作。如果需要此功能,请考虑以下选项:

  • 采用鸭子类型(直接使用类而非显式检查其是否属于给定类型)。
  • 将类型关系明确作为类契约的一部分。例如,可以添加属性标签self.handler = "handle_me_this_way",并让客户端代码检查handler的值而非直接检查类型。

torch.package 如何实现包之间的隔离

每个 PackageImporter 实例都会为其模块和对象创建一个独立的隔离环境。包中的模块只能导入其他已打包的模块或被标记为 extern 的模块。如果使用多个 PackageImporter 实例加载同一个包,将会得到多个互不干扰的独立环境。

这一机制是通过扩展 Python 的导入系统实现的,PackageImporter 是一个自定义导入器。它提供了与 importlib 导入器相同的核心 API,即实现了 import_module__import__ 方法。

当调用 PackageImporter.import_module() 时,PackageImporter 会像系统导入器一样构造并返回一个新模块。不同之处在于,它会修改返回的模块,使其在后续导入请求时使用当前 PackageImporter 实例(即从包内查找资源),而不是从用户的 Python 环境中搜索。


名称修饰(Mangling)

为了避免混淆(“这个 foo.bar 对象是来自我的包,还是来自 Python 环境?”),PackageImporter 会通过添加修饰前缀来修改所有导入模块的 __name____file__ 属性。

对于 __name__,像 torchvision.models.resnet18 这样的名称会被修饰为 <torch_package_0>.torchvision.models.resnet18

对于 __file__,像 torchvision/models/resnet18.py 这样的路径会被修饰为 <torch_package_0>.torchvision/modules/resnet18.py

名称修饰有助于避免不同包之间模块名的意外冲突,并通过使堆栈跟踪和打印语句更清晰地显示它们是否引用打包代码来辅助调试。有关名称修饰的开发者详细信息,请参阅 torch/package/ 目录下的 mangling.md 文件。


API 参考


class torch.package.PackagingError(dependency_graph, debug=False)

当导出包出现问题时,会引发此异常。

PackageExporter 会尝试收集所有错误并一次性展示给您。


class torch.package.EmptyMatchError

当模拟对象(mock)或外部依赖(extern)被标记为allow_empty=False,但在打包过程中未匹配到任何模块时,会抛出此异常。


class torch.package.PackageExporter(f, importer=<torch.package.importer._SysImporter object>, debug=False)

导出器(Exporters)允许你将代码包、序列化的Python数据以及任意二进制和文本资源打包成自包含的独立包。

导入器(Imports)能够以封闭方式加载这些代码,确保代码从包内加载而非通过常规Python导入系统。这种机制使得PyTorch模型代码和数据可以被打包,以便在服务器上运行或用于未来的迁移学习。

包中的代码在创建时会从原始源文件逐份复制,其文件格式采用特殊组织的zip压缩包。包的使用者可以解压该包并编辑代码,以实现自定义修改。

包的导入器会确保模块代码只能从包内加载,除非模块通过extern()明确声明为外部依赖。zip压缩包中的extern_modules文件列出了该包所有外部依赖的模块。

这种机制避免了"隐式"依赖问题——即包在本地运行时能正常导入本地安装的依赖,但当包被复制到其他机器时就会运行失败。

当源代码被添加到包中时,导出器可选择性地扫描代码以发现更多依赖关系(通过设置dependencies=True)。它会查找import语句,将相对引用解析为完整模块名,并执行用户指定的操作(参见:extern()mock()intern())。


__init__(f, importer=<torch.package.importer._SysImporter object>, debug=False)

创建一个导出器。

参数

  • f ( Union [str,* PathLike[str],* [IO](https://docs.python.org/3/library/typing.html#typing.IO "(in Python v3.13)")[bytes ]]) - 导出目标位置。可以是包含文件名的string/Path对象,也可以是二进制I/O对象。
  • importer ( Union [Importer*,* Sequence [Importer]]) - 如果传入单个Importer,则使用它来搜索模块。如果传入一个importer序列,则会基于它们构建一个OrderedImporter
  • debug ([bool]) - 如果设为True,会将损坏模块的路径添加到PackagingErrors中。

add_dependency(module_name, dependencies=True)

根据用户指定的模式,将给定模块添加到依赖关系图中。


all_paths(src, dst)

返回从源节点到目标节点所有路径的子图点表示形式。

返回值:包含从源节点到目标节点所有路径的点表示字符串。

参考文档:Graphviz语言规范

返回类型:str


close()

将包写入文件系统。调用close()方法后,所有后续操作都将无效。

建议改用资源守卫语法:

with PackageExporter("file.zip") as e:...

denied_modules()

返回当前被拒绝的所有模块。

返回值:一个包含该包中被拒绝模块名称的列表。

返回类型:list [str]


deny(include, *, exclude=())

根据给定的glob模式,从包可导入的模块列表中屏蔽匹配名称的模块。

如果发现任何匹配包的依赖项,将抛出 PackagingError 错误。

参数

  • include (Union[List[str],* str]) - 可以是字符串(例如 "my_package.my_subpackage")或模块名称列表,用于指定需要外部化的模块。也支持glob风格的模式,详见 mock()
  • exclude (Union[List[str],* str]) - 可选参数,用于排除与include字符串匹配的部分模式。

dependency_graph_string()

返回包中依赖关系的有向图字符串表示形式。

返回值:包中依赖关系的字符串表示形式。

返回类型:str


extern(include, *, exclude=(), allow_empty=True)

module 包含在包可导入的外部模块列表中。

这将阻止依赖项发现机制将其保存到包中。导入器会直接从标准导入系统加载外部模块。

外部模块的代码也必须存在于加载包的进程中。

参数

  • include (Union[List[str],* str]) – 字符串(例如 "my_package.my_subpackage")或待外部化的模块名称字符串列表。也可使用通配符模式,如 mock() 所述。
  • exclude (Union[List[str],* str]) – 可选参数,用于排除与 include 字符串匹配的部分模式。
  • allow_empty ([bool]) – 可选标志,指定本次调用 extern 方法所定义的外部模块是否必须在打包时匹配到某些模块。若以 allow_empty=False 添加外部模块通配符模式,且在匹配到任何模块前调用 close()(显式调用或通过 __exit__ 调用),则会抛出异常。若 allow_empty=True 则不会抛出此类异常。

(注:严格保留所有代码块、术语标记及链接格式,技术术语如externmock()等未作翻译,被动语态已转换为主动表达,长句进行了合理拆分)


externed_modules()

返回当前被外部化的所有模块。

返回值:一个包含该包中将被外部化的模块名称的列表。

返回类型:list [str]


get_rdeps(module_name)

返回所有依赖于模块 module_name 的模块列表。

返回值:一个包含依赖 module_name 的模块名称列表。

返回类型:list [str]


get_unique_id()

获取一个ID。该ID保证在此包中只会被分配一次。

返回类型:str


intern(include, *, exclude=(), allow_empty=True)

指定需要打包的模块。模块必须匹配某些intern模式才能被包含在包中,并递归处理其依赖项。

参数

  • include (Union[List[str],* str]) – 可以是字符串(例如"my_package.my_subpackage")或模块名称列表,用于指定需要外部化的模块。该参数也支持glob风格的模式匹配,具体说明见mock()文档。
  • exclude (Union[List[str],* str]) – 可选参数,用于排除与include模式匹配的特定模式。
  • allow_empty ([bool]) – 可选标志,指定通过intern方法设置的内部模块在打包时是否必须匹配到某些模块。如果添加intern模块glob模式时设置allow_empty=False,且在调用close()(显式调用或通过__exit__)时没有任何模块匹配该模式,则会抛出异常。若设置allow_empty=True,则不会抛出此类异常。

interned_modules()

返回当前被内部化的所有模块。

返回值:一个包含将被此包内部化的模块名称的列表。

返回类型:list [str]


mock(include, *, exclude=(), allow_empty=True)

用模拟实现替换某些必需的模块。被模拟的模块将返回一个虚假对象,用于处理任何从其访问的属性。由于我们采用逐文件复制的方式,依赖解析有时会找到被模型文件导入但实际功能从未使用的文件(例如自定义序列化代码或训练辅助工具)。

使用此函数可以模拟这些功能,而无需修改原始代码。

参数

  • include (Union[List[str],* str])

一个字符串(例如 "my_package.my_subpackage")或字符串列表,表示需要被模拟的模块名称。字符串也可以使用通配符模式,可能匹配多个模块。任何符合此模式字符串的必需依赖项都将被自动模拟。

示例:
'torch.**' – 匹配 torch 及其所有子模块,例如 'torch.nn''torch.nn.functional'
'torch.*' – 匹配 'torch.nn''torch.functional',但不匹配 'torch.nn.functional'

  • exclude (Union[List[str],* str]) – 一个可选模式,用于排除某些匹配 include 字符串的模式。
    例如,include='torch.**', exclude='torch.foo' 将模拟除 'torch.foo' 之外的所有 torch 包。默认值为 []
  • allow_empty ([bool]) – 一个可选标志,指定通过调用 mock() 方法指定的模拟实现是否必须在打包过程中匹配到某些模块。如果以 allow_empty=False 添加模拟,并且在调用 close()(显式调用或通过 __exit__)时该模拟未匹配到导出包所使用的模块,则会抛出异常。

如果 allow_empty=True,则不会抛出此类异常。


mocked_modules()

返回当前被模拟的所有模块。

返回值:包含本包中将被模拟的模块名称列表。

返回类型:list [str]


register_extern_hook(hook)

在导出器上注册一个外部钩子。

每当有模块匹配 extern() 模式时,就会调用该钩子。

钩子函数需要遵循以下签名:

hook(exporter: PackageExporter, module_name: str) -None

钩子将按照注册顺序被调用。

返回值:一个句柄,可用于通过调用 handle.remove() 来移除已添加的钩子。

返回类型:torch.utils.hooks.RemovableHandle


register_intern_hook(hook)

在导出器上注册一个内部钩子。

每当模块匹配 intern() 模式时,都会调用该钩子。

它应具有以下签名:

hook(exporter: PackageExporter, module_name: str) -None

钩子将按照注册顺序被调用。

返回值:一个句柄,可用于通过调用 handle.remove() 来移除已添加的钩子。

返回类型:torch.utils.hooks.RemovableHandle


register_mock_hook(hook)

在导出器上注册一个模拟钩子。

每当模块匹配 mock() 模式时,该钩子就会被调用。

它应具有以下签名:

hook(exporter: PackageExporter, module_name: str) -None

钩子函数将按照注册顺序依次调用。

返回值:返回一个句柄,可通过调用 handle.remove() 来移除已添加的钩子。

返回类型:torch.utils.hooks.RemovableHandle


save_binary(package, resource, binary)

将原始字节保存到包中。

参数

  • package (str) – 该资源所属的模块包名称(例如 "my_package.my_subpackage")。
  • resource (str) – 资源的唯一名称,用于加载时识别。
  • binary (str) – 要保存的数据。

save_module(module_name, dependencies=True)

module 的代码保存到包中。模块代码的解析过程是:先通过 importers 路径查找模块对象,然后利用其 __file__ 属性定位源代码。

参数

  • module_name (str) – 例如 my_package.my_subpackage,代码将被保存以提供该包的实现代码。
  • dependencies ([bool], 可选) – 若设为 True,则会扫描源代码中的依赖项。

save_pickle(package, resource, obj, dependencies=True, pickle_protocol=3)

使用pickle将Python对象保存到归档文件中。功能等同于torch.save(),但会保存到归档而非独立文件。标准pickle不会保存代码,仅保存对象。

如果dependencies参数为True,此方法还会扫描被pickle的对象,识别重建它们所需的模块,并保存相关代码。

要保存type(obj).__name__my_module.MyObject的对象时,my_module.MyObject必须能根据importer顺序解析为对象的类。当保存先前已打包的对象时,importer列表中必须包含import_module方法才能正常工作。

参数

  • package (str) - 该资源所属的模块包名称(例如"my_package.my_subpackage")。
  • resource (str) - 资源的唯一名称,用于加载时识别。
  • obj (Any) - 要保存的对象,必须可被pickle序列化。
  • dependencies ([bool], 可选) - 若为True,则会扫描源代码中的依赖项。

save_source_file(module_name, file_or_directory, dependencies=True)

将本地文件系统中的 file_or_directory 添加到源码包中,为 module_name 提供代码。

参数

  • module_name (str) – 例如 "my_package.my_subpackage",代码将被保存以提供该包的代码。
  • file_or_directory (str) – 代码文件或目录的路径。如果是目录,将递归复制该目录中的所有 Python 文件,使用 save_source_file()。如果文件名为 "/__init__.py",则该代码被视为一个包。
  • dependencies ([bool], 可选) – 如果为 True,则会扫描源码中的依赖项。

save_source_string(module_name, src, is_package=False, dependencies=True)

src作为导出包中module_name的源代码添加。

参数

  • module_name (str) – 例如my_package.my_subpackage,代码将被保存以提供该包的源代码。
  • src (str) – 要为此包保存的Python源代码。
  • is_package ([bool], 可选) – 如果为True,则将此模块视为包。包允许包含子模块(例如my_package.my_subpackage.my_subsubpackage),并且可以在其中保存资源。默认为False
  • dependencies ([bool], 可选) – 如果为True,则会扫描源代码中的依赖项。

save_text(package, resource, text)

将文本数据保存到包中。

参数

  • package (str) – 该资源所属模块包的名称(例如 "my_package.my_subpackage")。
  • resource (str) – 资源的唯一名称,用于加载时识别。
  • text (str) – 要保存的内容。

class torch.package.PackageImporter(file_or_buffer, module_allowed=<function PackageImporter.<lambda>>)

导入器(Importers)允许您加载由PackageExporter写入包的代码。

代码以密封方式加载,使用包中的文件而非常规Python导入系统。这使得PyTorch模型代码和数据可以被打包,从而能在服务器上运行或用于未来的迁移学习。

包导入器确保模块中的代码只能从包内部加载,除非是导出时明确列为外部的模块。

zip存档中的extern_modules文件列出了包外部依赖的所有模块。

这避免了"隐式"依赖问题——即包在本地运行时能正常工作(因为导入了本地安装的包),但当包被复制到其他机器时就会失败。


__init__(file_or_buffer, module_allowed=<function PackageImporter.<lambda>>)

打开 file_or_buffer 以进行导入操作。此操作会检查导入的包是否仅依赖 module_allowed 所允许的模块。

参数

  • file_or_buffer ( Union [str,* PathLike[str],* [IO](https://docs.python.org/3/library/typing.html#typing.IO "(in Python v3.13)")[bytes ], PyTorchFileReader]) – 类文件对象(需实现 read()readline()tell()seek() 方法)、字符串或包含文件名的 os.PathLike 对象。
  • module_allowed (Callable[[str],* [bool]], optional) – 用于判断是否应允许外部提供模块的方法。可用于确保加载的包不依赖服务器不支持的模块。默认允许所有模块。

抛出异常

ImportError` – 如果包尝试使用被禁止的模块。


file_structure(*, include='**', exclude=())

返回包 zipfile 的文件结构表示。

参数

  • include (Union[List[str],* str]) – 可选字符串(如 "my_package.my_subpackage")或可选字符串列表,用于指定要包含在 zipfile 表示中的文件名。也可以是 glob 风格的模式,如 PackageExporter.mock() 中所述。
  • exclude (Union[List[str],* str]) – 可选模式,用于排除名称匹配该模式的文件。

返回

Directory

返回类型:Directory

id()

返回 torch.package 用于区分 PackageImporter 实例的内部标识符。

格式类似:

<torch_package_0>

import_module(name, package=None)

如果模块尚未加载,则从包中加载该模块并返回。模块会被加载到导入者的本地命名空间,并出现在 self.modules 中而非 sys.modules 里。

参数

  • name (str) – 要加载模块的完整限定名。
  • package ([type], 可选) – 未使用,但为了与 importlib.import_module 的函数签名保持一致而保留。默认为 None

返回值

(可能已加载的)模块对象。

返回类型

types.ModuleType

load_binary(package, resource)

加载原始字节数据。

参数

  • package (str) – 模块包的名称(例如 "my_package.my_subpackage")。
  • resource (str) – 资源的唯一名称。

返回值:已加载的数据。

返回类型:bytes


load_pickle(package, resource, map_location=None)

从包中反序列化资源,加载构造对象所需的所有模块

使用 import_module()

参数

  • package (str) – 模块包的名称(例如 "my_package.my_subpackage")。
  • resource (str) – 资源的唯一名称。
  • map_location – 传递给 torch.load 以确定张量如何映射到设备。默认为 None

返回

反序列化后的对象。

返回类型:任意


load_text(package, resource, encoding='utf-8', errors='strict')

加载字符串。

参数

  • package (str) – 模块包的名称(例如 "my_package.my_subpackage")。
  • resource (str) – 资源的唯一名称。
  • encoding (str, 可选) – 传递给 decode。默认为 'utf-8'
  • errors (str, 可选) – 传递给 decode。默认为 'strict'

返回值:加载的文本。

返回类型:str


python_version()

返回用于创建此软件包的 Python 版本。

注意:此功能为实验性质,不具备向前兼容性。计划后续将其迁移至锁文件中。

返回值:Optional[str] 返回 Python 版本号(例如 3.8.9),如果该软件包未存储版本信息则返回 None


class torch.package.Directory(name, is_dir)

一种文件结构表示形式。它以目录节点形式组织,每个节点包含其子目录列表。通过调用 PackageImporter.file_structure() 可为包创建目录结构。


has_file(filename)

检查文件是否存在于 Directory 中。

参数

  • filename (str) - 要搜索的文件路径。

返回值

如果 Directory 包含指定文件则返回 True。

返回类型:bool



torch.profiler


概述

PyTorch Profiler 是一款用于在训练和推理过程中收集性能指标的工具。通过其上下文管理器 API,开发者可以深入分析模型中最耗时的算子、检查输入张量形状和调用堆栈、研究设备内核活动,并可视化执行轨迹。


注意: torch.autograd 模块中的旧版 API 已被视为遗留接口,未来将被弃用。


API 参考


class torch.profiler._KinetoProfile(*, activities=None, record_shapes=False, profile_memory=False, with_stack=False, with_flops=False, with_modules=False, experimental_config=None, execution_trace_observer=None, acc_events=False, custom_trace_id_callback=None)

底层分析器封装了自动梯度分析功能

参数

  • activities (iterable) – 用于分析的活跃组列表(CPU、CUDA),支持以下值:

torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, torch.profiler.ProfilerActivity.XPU.

默认值:ProfilerActivity.CPU 和(可用时)ProfilerActivity.CUDA 或(可用时)ProfilerActivity.XPU。

  • record_shapes ([bool]) – 保存算子输入形状信息。
  • profile_memory ([bool]) – 跟踪张量内存分配/释放(详见export_memory_timeline)。
  • with_stack ([bool]) – 记录算子的源码信息(文件和行号)。
  • with_flops ([bool]) – 使用公式估算特定算子(矩阵乘法和2D卷积)的浮点运算次数。
  • with_modules ([bool]) – 记录与算子调用栈对应的模块层次结构(包括函数名)。

例如:如果模块A的forward调用模块B的forward(其中包含aten::add算子),那么aten::add的模块层次结构就是A.B

注意:目前该功能仅支持TorchScript模型,不支持eager模式模型。

  • experimental_config (_ExperimentalConfig) – 一组实验性选项,供Kineto等分析器库使用。注意不保证向后兼容性。
  • execution_trace_observer (ExecutionTraceObserver) – PyTorch执行轨迹观察器对象。

PyTorch执行轨迹提供了基于图的AI/ML工作负载表示,支持回放基准测试、模拟器和仿真器。

当包含此参数时,观察器的start()和stop()方法将与PyTorch分析器在同一时间窗口被调用。

  • acc_events ([bool]) – 启用跨多个分析周期的FunctionEvents累积功能

注意:此API为实验性质,未来可能变更。

启用形状和调用栈跟踪会导致额外开销。

当指定record_shapes=True时,分析器会临时持有张量的引用,这可能会阻止某些依赖引用计数的优化,并引入额外的张量拷贝。


add_metadata(key, value)

向跟踪文件中添加用户定义的元数据,包含字符串键和字符串值


add_metadata_json(key, value)

向跟踪文件中添加用户自定义的元数据,包含字符串键和有效的JSON值


events()

返回未聚合的性能分析事件列表,可用于跟踪回调或在性能分析结束后使用


export_chrome_trace(path)

以 Chrome JSON 格式导出收集的跟踪数据。如果启用了 kineto,则仅导出调度中的最后一个周期。


export_memory_timeline(path, device=None)

从分析器收集的内存事件信息树中导出指定设备的数据,并生成时间线图表。使用export_memory_timeline可导出三种文件格式,通过path参数的后缀控制:

  • 如需生成HTML兼容的图表,使用.html后缀。内存时间线图表将以PNG格式嵌入HTML文件中。
  • 如需导出由[时间戳, [按类别划分的内存大小]]组成的数据点(其中times是时间戳,sizes是每个类别的内存使用量),根据后缀选择保存为JSON(.json)或gzip压缩的JSON(.json.gz)。
  • 如需原始内存事件数据,使用.raw.json.gz后缀。每个原始内存事件将包含(时间戳, 操作类型, 字节数, 类别),其中:
    • action[PREEXISTING, CREATE, INCREMENT_VERSION, DESTROY]之一
    • category来自torch.profiler._memory_profiler.Category枚举

输出:内存时间线数据将以gzip压缩JSON、JSON或HTML格式写入。


export_stacks(path, metric='self_cpu_time_total')

将堆栈跟踪保存到文件

参数

  • path (str) - 将堆栈文件保存到此路径;
  • metric (str) - 使用的指标:“self_cpu_time_total” 或 “self_cuda_time_total”

key_averages(group_by_input_shape=False, group_by_stack_n=0, group_by_overload_name=False) 

Averages events, grouping them by operator name and (optionally) input shapes, stack and overload name.


Note: To use shape/stack functionality make sure to set record_shapes/with_stack
when creating profiler context manager.


preset_metadata_json(key, value)

在性能分析器未启动时预设用户自定义元数据,该元数据后续会被添加到跟踪文件中。

元数据格式为字符串键与有效JSON值的组合


toggle_collection_dynamic(enable, activities)

功能说明

可在收集过程中的任意时间点开启/关闭活动收集功能。当前支持切换 Torch 算子(CPU)以及 Kineto 中支持的 CUDA 活动。

参数说明

  • activities (iterable) – 用于性能分析的活动组列表,支持以下取值:
    • torch.profiler.ProfilerActivity.CPU
    • torch.profiler.ProfilerActivity.CUDA

使用示例


with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU,torch.profiler.ProfilerActivity.CUDA,]
) as p:code_to_profile_0()// turn off collection of all CUDA activityp.toggle_collection_dynamic(False, [torch.profiler.ProfilerActivity.CUDA])code_to_profile_1()// turn on collection of all CUDA activityp.toggle_collection_dynamic(True, [torch.profiler.ProfilerActivity.CUDA])code_to_profile_2()
print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))

class torch.profiler.profile(*, activities=None, schedule=None, on_trace_ready=None, record_shapes=False, profile_memory=False, with_stack=False, with_flops=False, with_modules=False, experimental_config=None, execution_trace_observer=None, acc_events=False, use_cuda=None, custom_trace_id_callback=None)

分析器上下文管理器。

参数

  • activities (iterable) – 用于分析的活跃组列表(CPU、CUDA),支持以下值:

torch.profiler.ProfilerActivity.CPUtorch.profiler.ProfilerActivity.CUDAtorch.profiler.ProfilerActivity.XPU

默认值:ProfilerActivity.CPU 和(如果可用)ProfilerActivity.CUDA 或(如果可用)ProfilerActivity.XPU。

  • schedule (Callable) – 可调用对象,接收步骤(int)作为单一参数,并返回

ProfilerAction 值,指定在每一步执行的分析器操作。

  • on_trace_ready (Callable) – 当 schedule 在分析过程中返回 ProfilerAction.RECORD_AND_SAVE 时,每一步调用的可调用对象。
  • record_shapes ([bool]) – 保存操作符输入形状的信息。
  • profile_memory ([bool]) – 跟踪张量内存分配/释放。
  • with_stack ([bool]) – 记录操作符的源信息(文件和行号)。
  • with_flops ([bool]) – 使用公式估算特定操作符(矩阵乘法和2D卷积)的浮点运算次数(FLOPs)。
  • with_modules ([bool]) – 记录与操作符调用堆栈对应的模块层次结构(包括函数名称)。例如,如果模块A的前向调用模块B的前向,其中包含一个aten::add操作符,那么aten::add的模块层次结构是A.B。

注意,目前此功能仅支持TorchScript模型,不支持eager模式模型。

  • experimental_config (_ExperimentalConfig) – 一组用于Kineto库功能的实验性选项。注意,不保证向后兼容性。
  • execution_trace_observer (ExecutionTraceObserver) – 一个PyTorch执行跟踪观察器对象。

PyTorch执行跟踪 提供基于图的AI/ML工作负载表示,并支持重放基准测试、模拟器和仿真器。

当包含此参数时,观察器的start()和stop()将在与PyTorch分析器相同的时间窗口内调用。请参阅下面的示例部分获取代码示例。

  • acc_events ([bool]) – 启用跨多个分析周期的FunctionEvents累积。
  • use_cuda ([bool])

自版本1.8.1起弃用:改用 activities

注意:使用 schedule() 生成可调用的计划。

非默认计划在分析长时间训练作业时非常有用,允许用户在训练过程的不同迭代中获取多个跟踪。

默认计划只是在上下文管理器持续时间内连续记录所有事件。

注意:使用 tensorboard_trace_handler() 生成TensorBoard的结果文件:

on_trace_ready=torch.profiler.tensorboard_trace_handler(dir_name)

分析完成后,结果文件可以在指定目录中找到。使用命令:

tensorboard --logdir dir_name

在TensorBoard中查看结果。

更多信息,请参阅 PyTorch Profiler TensorBoard Plugin

注意:启用形状和堆栈跟踪会导致额外的开销。

当指定record_shapes=True时,分析器将临时持有对张量的引用;这可能会进一步阻止某些依赖于引用计数的优化,并引入额外的张量副本。


示例

with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU,  torch.profiler.ProfilerActivity.CUDA, ]
) as p:code_to_profile()
print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))

使用性能分析器的 scheduleon_trace_readystep 函数:

# Non-default profiler schedule allows user to turn profiler on and off
# on different iterations of the training loop;
# trace_handler is called every time a new trace becomes available
def trace_handler(prof):print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))# prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step_num) + ".json")with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU,  torch.profiler.ProfilerActivity.CUDA, ], # In this example with wait=1, warmup=1, active=2, repeat=1, # profiler will skip the first step/iteration, # start warming up on the second, record# the third and the forth iterations, # after which the trace will become available# and on_trace_ready (when set) is called;# the cycle repeats starting with the next stepschedule=torch.profiler.schedule(wait=1,  warmup=1,  active=2,  repeat=1), on_trace_ready=trace_handler# on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')# used when outputting for tensorboard) as p:for iter in range(N):code_iteration_to_profile(iter)# send a signal to the profiler that the next iteration has startedp.step()

以下示例展示了如何设置执行跟踪观察器(execution_trace_observer)


with torch.profiler.profile(...execution_trace_observer=(ExecutionTraceObserver().register_callback("./execution_trace.json")), ) as p:for iter in range(N):code_iteration_to_profile(iter)p.step()

你也可以参考 tests/profiler/test_profiler.py 中的 test_execution_trace_with_kineto() 方法。

注意:任何实现了 _ITraceObserver 接口的对象都可以传入使用。


get_trace_id()

返回当前跟踪ID。


set_custom_trace_id_callback(callback)

设置当生成新跟踪ID时要调用的回调函数。


step()

通知性能分析器下一个分析步骤已开始。


class torch.profiler.ProfilerAction(value)

可在指定时间间隔执行的性能分析器操作


class torch.profiler.ProfilerActivity 

Members:

CPU

XPU

MTIA

CUDA

HPU

PrivateUse1


property name 

torch.profiler.schedule(*, wait, warmup, active, repeat=0, skip_first=0, skip_first_wait=0)

返回一个可调用对象,可作为分析器的schedule参数使用。该分析器会跳过前skip_first个步骤,然后等待wait个步骤,接着进行warmup个步骤的热身,随后执行active个步骤的活跃记录,最后从wait步骤开始重复这个循环。

通过repeat参数可指定可选的循环次数,值为零表示循环将持续到分析结束。

skip_first_wait参数控制是否跳过第一个wait阶段。

当用户希望在循环之间等待时间超过skip_first但首次分析不适用时,这个功能很有用。例如,若skip_first为10且wait为20,当skip_first_wait为零时,第一个循环将在热身前等待10+20=30个步骤;若skip_first_wait非零,则仅等待10个步骤。之后所有循环都会在最后一次活跃记录和热身之间等待20个步骤。

返回类型:Callable


torch.profiler.tensorboard_trace_handler(dir_name, worker_name=None, use_gzip=False)

将跟踪文件输出到 dir_name 目录,该目录可直接作为日志目录传递给 TensorBoard。

在分布式场景中,每个工作节点的 worker_name 应保持唯一,默认会设置为 ‘[hostname]_[pid]’。


Intel 插桩与追踪技术 API


torch.profiler.itt.is_available()

检查 ITT 功能是否可用


torch.profiler.itt.mark(msg)

描述某个时间点发生的瞬时事件。

参数

  • msg (str) – 与该事件关联的ASCII消息。

torch.profiler.itt.range_push(msg)

将一段范围压入嵌套范围跨度栈。返回所启动范围的从零开始的深度。

参数

  • msg (str) – 与该范围关联的ASCII消息

torch.profiler.itt.range_pop()

从嵌套范围跨度栈中弹出一个范围。返回被结束范围的从零开始的深度。



torch.nn.init


警告:本模块中的所有函数均用于初始化神经网络参数,因此它们都在 torch.no_grad() 模式下运行,且不会被自动求导机制追踪。


torch.nn.init.calculate_gain(nonlinearity, param=None)

返回给定非线性函数的推荐增益值。

具体数值如下:

非线性函数增益值
Linear / Identity111
Conv{1,2,3}D111
Sigmoid111
Tanh53\frac{5}{3}35
ReLU2\sqrt{2}2
Leaky ReluKaTeX parse error: Expected 'EOF', got '_' at position 34: … \text{negative_̲slope}^2}}​​
SELU34\frac{3}{4}43

警告:为了实现自归一化神经网络,应该使用nonlinearity='linear'而非nonlinearity='selu'。这样可以使初始权重具有1/N的方差,这对在前向传播中形成稳定不动点是必要的。

相比之下,SELU的默认增益值牺牲了归一化效果,以在矩形层中获得更稳定的梯度流。

参数说明

  • nonlinearity - 非线性函数(nn.functional名称)
  • param - 非线性函数的可选参数

使用示例


>>> gain = nn.init.calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2

torch.nn.init.uniform_(tensor, a=0.0, b=1.0, generator=None)

使用均匀分布中的值填充输入张量。

U(a,b){U}(a, b)U(a,b).

参数

  • tensor ( Tensor ) – 一个n维的torch.Tensor
  • a (float) – 均匀分布的下界
  • b (float) – 均匀分布的上界
  • generator (Optional[ Generator ]) – 用于采样的torch生成器(默认值:None)

返回类型:Tensor


示例

>>> w = torch.empty(3, 5)
>>> nn.init.uniform_(w)

torch.nn.init.normal_(tensor, mean=0.0, std=1.0, generator=None)

使用正态分布中的随机值填充输入张量。

N(mean,std2)\mathcal{N}(\text{mean}, \text{std}^2)N(mean,std2)

参数说明

  • tensor ( Tensor ) – 一个n维的torch.Tensor
  • mean (float) – 正态分布的均值
  • std (float) – 正态分布的标准差
  • generator (Optional[ Generator ]) – 用于采样的torch生成器对象(默认值为None)

返回值类型:Tensor

使用示例


>>> w = torch.empty(3, 5)
>>> nn.init.normal_(w)

torch.nn.init.constant_(tensor, val)

将输入张量填充为值 val\text{val}val。

参数

  • tensor ( Tensor ) – 一个 n 维的 torch.Tensor
  • val (float) – 用于填充张量的值

返回类型 : Tensor


示例

>>> w = torch.empty(3, 5)
>>> nn.init.constant_(w, 0.3)

torch.nn.init.ones_(tensor)

填充输入张量(Tensor)为标量值1。

参数

  • tensor ( Tensor ) – 一个n维的torch.Tensor

返回类型:Tensor


示例

>>> w = torch.empty(3, 5)
>>> nn.init.ones_(w)

torch.nn.init.zeros_(tensor)

将输入张量填充为标量值0。

参数

  • tensor ( Tensor ) - 一个n维的torch.Tensor

返回类型: Tensor


示例

>>> w = torch.empty(3, 5)
>>> nn.init.zeros_(w)

torch.nn.init.eye_(tensor)

将二维输入张量填充为单位矩阵。

在Linear层中保留输入的恒等性,尽可能多地保留输入特征。

参数

  • tensor – 一个二维的torch.Tensor张量

示例

>>> w = torch.empty(3, 5)
>>> nn.init.eye_(w)

torch.nn.init.dirac_(tensor, groups=1)

用狄拉克δ函数填充{3, 4, 5}维输入张量。

在卷积层中保持输入数据的特性,尽可能多地保留输入通道。当组数大于1时,每组通道都会保持其特性。

参数

  • tensor – 一个{3, 4, 5}维的torch.Tensor
  • groups (int, 可选) – 卷积层中的组数(默认值:1)

示例

>>> w = torch.empty(3, 16, 5, 5)
>>> nn.init.dirac_(w)
>>> w = torch.empty(3, 24, 5, 5)
>>> nn.init.dirac_(w, 3)

torch.nn.init.xavier_uniform_(tensor, gain=1.0, generator=None)

使用Xavier均匀分布为输入张量填充数值。

该方法在论文《理解深度前馈神经网络训练难点》- Glorot, X. 和 Bengio, Y. (2010)中有详细描述。生成的张量将从U(−a,a)分布中采样,其中:

a = gain × √(6/(fan_in + fan_out))

该初始化方法也被称为Glorot初始化。

参数:

  • tensor (Tensor) - 一个n维的torch.Tensor
  • gain (float) - 可选的缩放因子
  • generator (Optional[Generator]) - 用于采样的torch生成器(默认值:None)

返回类型:Tensor


示例:

>>> w = torch.empty(3, 5)
>>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))

注意:fan_infan_out 的计算基于权重矩阵以转置方式使用的假设(例如 Linear 层中的 x @ w.T,其中 w.shape = [fan_out, fan_in])。

这对正确初始化至关重要。

如果计划使用 x @ w(其中 w.shape = [fan_in, fan_out]),请传入转置后的权重矩阵,即 nn.init.xavier_uniform_(w.T, ...)


torch.nn.init.xavier_normal_(tensor, gain=1.0, generator=None)

使用Xavier正态分布为输入张量填充数值。

该方法在论文《理解深度前馈神经网络训练的难点》- Glorot, X. 和 Bengio, Y. (2010) 中提出。生成的张量将从U(−a,a)\mathcal{U}(-a, a)U(a,a)分布中采样数值,其中

KaTeX parse error: Expected 'EOF', got '_' at position 48: …ac{6}{\text{fan_̲in} + \text{fan…

参数

  • tensor ( Tensor ) – 一个n维的torch.Tensor
  • gain (float) – 可选的缩放因子
  • generator (Optional[ Generator ]) – 用于采样的torch生成器(默认值:None)

返回类型:Tensor


示例

>>> w = torch.empty(3, 5)
>>> nn.init.xavier_normal_(w)

注意:fan_infan_out 的计算基于权重矩阵以转置方式使用的假设(例如 Linear 层中的 x @ w.T,其中 w.shape = [fan_out, fan_in])。

这一点对于正确初始化至关重要。

如果你计划使用 x @ w(其中 w.shape = [fan_in, fan_out]),请传入转置后的权重矩阵,即 nn.init.xavier_normal_(w.T, ...)


torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu', generator=None)

使用Kaiming均匀分布为输入张量填充数值。

该方法出自论文《Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification》——何恺明等人(2015年)。生成的张量数值将采样自U(−bound,bound)U(−bound,bound)U(bound,bound)区间,其中边界值计算公式为:

KaTeX parse error: Expected 'EOF', got '_' at position 59: …ac{3}{\text{fan_̲mode}}}

该初始化方法也被称为He初始化。

参数说明

  • tensor ( Tensor ) – 一个n维的torch.Tensor
  • a (float) – 该层后接整流器的负斜率(仅当使用'leaky_relu'时生效)
  • mode (str) – 可选'fan_in'(默认)或'fan_out'。选择'fan_in'可保持前向传播中权重的方差量级,选择'fan_out'则保持反向传播中的量级
  • nonlinearity (str) – 非线性函数名称(需对应nn.functional中的名称),建议仅使用'relu''leaky_relu'(默认值)
  • generator (Optional[ Generator ]) – 用于采样的torch生成器对象(默认值为None)

使用示例


>>> w = torch.empty(3, 5)
>>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')

注意:fan_infan_out 的计算基于权重矩阵以转置方式使用的假设(例如在 Linear 层中的 x @ w.T,其中 w.shape = [fan_out, fan_in])。

这一点对于正确的初始化非常重要。

如果你计划使用 x @ w(其中 w.shape = [fan_in, fan_out]),请传入转置后的权重矩阵,即 nn.init.kaiming_uniform_(w.T, ...)


torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu', generator=None)

使用Kaiming正态分布为输入张量填充数值。

该方法在论文《深入研究整流器:超越ImageNet分类的人类水平性能》- He, K. 等人 (2015) 中有详细描述。生成的张量值将从 N(0,std2)N(0,std^2)N(0,std2) 分布中采样,其中

KaTeX parse error: Expected 'EOF', got '_' at position 48: …\sqrt{\text{fan_̲mode}}}

该方法也被称为He初始化。

参数

  • tensor ( Tensor ) – 一个n维的torch.Tensor
  • a (float) – 该层之后使用的整流器的负斜率(仅与'leaky_relu'一起使用)
  • mode (str) – 可选'fan_in'(默认)或'fan_out'。选择'fan_in'会保持前向传播中权重的方差大小,选择'fan_out'则会保持反向传播中的大小。
  • nonlinearity (str) – 非线性函数(nn.functional名称),建议仅与'relu''leaky_relu'(默认)一起使用。
  • generator (Optional[ Generator ]) – 用于采样的torch生成器(默认:None)

示例

>>> w = torch.empty(3, 5)
>>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')

注意:fan_infan_out 的计算基于权重矩阵以转置方式使用的假设(例如 Linear 层中的 x @ w.T,其中 w.shape = [fan_out, fan_in])。

这对正确初始化至关重要。

如果计划使用 x @ w(其中 w.shape = [fan_in, fan_out]),请传入转置后的权重矩阵,即 nn.init.kaiming_normal_(w.T, ...)


torch.nn.init.trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0, generator=None)

使用截断正态分布生成的值填充输入张量。

这些值实际上是从正态分布 N(mean,std2)N(mean,std^2)N(mean,std2) 中抽取的,对于超出 [a,b][a,b][a,b] 范围的值会重新抽取,直到其落在边界内。当满足 a≤mean≤ba \leq mean \leq bameanb 时,生成随机值的方法效果最佳。

参数

  • tensor ( Tensor ) – 一个n维的torch.Tensor
  • mean (float) – 正态分布的均值
  • std (float) – 正态分布的标准差
  • a (float) – 最小截断值
  • b (float) – 最大截断值
  • generator (Optional[ Generator ]) – 用于采样的torch生成器(默认值:None)

返回类型:Tensor


示例

>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)

torch.nn.init.orthogonal_(tensor, gain=1, generator=None)

用(半)正交矩阵填充输入张量。

该方法在《Exact solutions to the nonlinear dynamics of learning in deep linear neural networks》- Saxe, A. 等人 (2013) 的论文中有所描述。输入张量必须至少具有2个维度,对于超过2维的张量,其尾部维度会被展平。

参数

  • tensor – 一个n维的 torch.Tensor,其中 $n \geq 2 $
  • gain – 可选缩放因子
  • generator (Optional[ Generator ]) – 用于采样的 torch Generator(默认值:None)

示例

>>> w = torch.empty(3, 5)
>>> nn.init.orthogonal_(w)

torch.nn.init.sparse_(tensor, sparsity, std=0.01, generator=None)

将二维输入张量填充为稀疏矩阵。

非零元素将从正态分布 N(0,0.01)N(0, 0.01)N(0,0.01) 中抽取,如《Deep learning via Hessian-free optimization》- Martens, J. (2010) 所述。

参数

  • tensor – 一个 n 维的 torch.Tensor
  • sparsity – 每列中要置零的元素比例
  • std – 用于生成非零值的正态分布的标准差
  • generator (Optional[ Generator ]) – 用于采样的 torch 生成器(默认值:None)

示例

>>> w = torch.empty(3, 5)
>>> nn.init.sparse_(w, sparsity=0.1)


torch.nn.attention

该模块包含用于改变 torch.nn.functional.scaled_dot_product_attention 行为的函数和类


工具集

sdpa_kernel用于选择缩放点积注意力计算后端的上下文管理器
SDPBackend枚举类,包含缩放点积注意力计算的不同后端实现

子模块

flex_attention该模块实现了PyTorch中flex_attention的用户接口API。
bias定义了与scaled_dot_product_attention配合使用的偏置子类
experimental

2025-08-20(三)

http://www.xdnf.cn/news/18365.html

相关文章:

  • 372. 超级次方
  • IIS访问报错:HTTP 错误 500.19 - Internal Server Error
  • Spring Retry实战指南_让你的应用更具韧性
  • 区块链技术:重塑未来互联网的伟大动力
  • Python Day32 JavaScript 数组与对象核心知识点整理
  • 源码编译部署 LAMP 架构详细步骤说明
  • Java设计模式-命令模式
  • python的校园顺路代送系统
  • Day 40:训练和测试的规范写法
  • Flink实现Exactly-Once语义的完整技术分解
  • 利用无事务方式插入数据库解决并发插入问题(最小主键id思路)
  • idea进阶技能掌握, 自带HTTP测试工具HTTP client使用方法详解,完全可替代PostMan
  • 暖哇科技AI调查智能体上线,引领保险调查风控智能化升级
  • 【数据结构】排序算法全解析:概念与接口
  • RK android14 Setting一级菜单IR遥控器无法聚焦问题解决方法
  • Apache ShenYu和Nacos之间的通信原理
  • VPS海外节点性能监控全攻略:从基础配置到高级优化
  • Android 入门到实战(三):ViewPager及ViewPager2多页面布局
  • 数据预处理学习心得:从理论到实践的桥梁搭建
  • 比剪映更轻量!SolveigMM 视频无损剪切实战体验
  • 29.Linux rsync+inotify解决同步数据实时性
  • 3D检测笔记:相机模型与坐标变换
  • 详解 scikit-learn 数据预处理工具:从理论到实践
  • CS+ for CC编译超慢的问题该如何解决
  • Day23 双向链表
  • 计算机网络--HTTP协议
  • 亚马逊新品爆单策略:从传统困境到智能突破
  • 【Grafana】grafana-image-renderer配合python脚本实现仪表盘导出pdf
  • 给你的Unity编辑器添加实现类似 Odin 的 条件显示字段 (ShowIf/HideIf) 功能
  • word——如何给封面、目录、摘要、正文设置不同的页码