当前位置: 首页 > news >正文

PyTorch API 1 - 概述、数学运算、nn、实用工具、函数、张量

文章目录

  • torch
    • 张量
      • 创建操作
      • 索引、切片、连接与变异操作
    • 加速器
    • 生成器
    • 随机采样
      • 原地随机采样
      • 准随机采样
    • 序列化
    • 并行计算
    • 局部禁用梯度计算
    • 数学运算
      • 常量
      • 逐点运算
      • 归约操作
      • 比较运算
      • 频谱操作
      • 其他操作
      • BLAS 和 LAPACK 运算
      • 遍历操作
      • 遍历操作
      • 遍历操作
      • 遍历操作
      • 遍历操作
      • 遍历操作
      • 遍历操作
      • 遍历操作
      • 遍历操作
      • 遍历操作
      • 遍历操作
      • 遍历操作
    • 实用工具
    • 符号数字
    • 导出路径
    • 控制流
    • 优化方法
    • 操作符标签
  • torch.nn
    • 容器模块
    • 模块全局钩子
    • 卷积层
    • 池化层
    • 填充层
    • 非线性激活函数(加权求和与非线性变换)
    • 非线性激活函数(其他)
    • 归一化层
    • 循环神经网络层
    • Transformer 层
    • 线性层
    • Dropout 层
    • 稀疏层
    • 距离函数
    • 损失函数
    • 视觉层
    • 通道混洗层
    • 数据并行层(多GPU,分布式)
    • 实用工具
      • 参数梯度裁剪工具
      • 模块参数扁平化与反扁平化工具
      • 模块与批归一化融合工具
      • 模块参数内存格式转换工具
      • 权重归一化应用与移除工具
      • 模块参数初始化工具
      • 模块参数剪枝工具类与函数
      • 使用`torch.nn.utils.parameterize.register_parametrization()`新参数化功能实现的参数化
      • 为现有模块上的张量参数化的实用函数
      • 以无状态方式调用给定模块的实用函数
      • 其他模块中的实用函数
    • 量化函数
    • 惰性模块初始化
      • 别名
  • torch.nn.functional
    • 卷积函数
    • 池化函数
    • 注意力机制
    • 非线性激活函数
    • 线性函数
    • Dropout 函数
    • 稀疏函数
    • 距离函数
    • 损失函数
    • 视觉函数
    • DataParallel 功能(多GPU,分布式)
      • data_parallel
  • torch.Tensor
    • 数据类型
    • 初始化与基础操作
    • Tensor 类参考
  • 张量属性
    • torch.dtype
    • torch.device
    • torch.layout
    • torch.memory_format
  • 张量视图


torch

torch 包提供了多维张量的数据结构,并定义了针对这些张量的数学运算。此外,它还包含许多实用工具,可实现张量及任意类型的高效序列化,以及其他实用功能。

该包还配有 CUDA 版本,支持在计算能力 >= 3.0 的 NVIDIA GPU 上运行张量计算。


张量

is_tensor如果 obj 是 PyTorch 张量则返回 True。
is_storage如果 obj 是 PyTorch 存储对象则返回 True。
is_complex如果 input 的数据类型是复数类型(即 torch.complex64torch.complex128)则返回 True。
is_conj如果 input 是共轭张量(即其共轭位被设为 True)则返回 True。
is_floating_point如果 input 的数据类型是浮点类型(即 torch.float64torch.float32torch.float16torch.bfloat16)则返回 True。
is_nonzero如果 input 是单元素张量且在类型转换后不等于零则返回 True。
set_default_dtype将默认浮点数据类型设置为 d
get_default_dtype获取当前默认浮点 torch.dtype
set_default_device设置默认 torch.Tensor 分配到的设备为 device
get_default_device获取默认 torch.Tensor 分配到的设备 device
set_default_tensor_type
numel返回 input 张量中的元素总数。
set_printoptions设置打印选项。
set_flush_denormal在 CPU 上禁用非正规浮点数。

创建操作


随机采样创建操作列在随机采样下,包括:torch.rand()torch.rand_like()torch.randn()torch.randn_like()torch.randint()torch.randint_like()torch.randperm()。你也可以使用torch.empty()结合原地随机采样方法来创建torch.Tensor,其值从更广泛的分布中采样。

tensor通过复制data构造一个没有自动求导历史的张量(也称为“叶子张量”,参见自动求导机制)。
sparse_coo_tensor在给定的indices处构造一个COO(坐标)格式的稀疏张量。
sparse_csr_tensor在给定的crow_indicescol_indices处构造一个CSR(压缩稀疏行)格式的稀疏张量。
sparse_csc_tensor在给定的ccol_indicesrow_indices处构造一个CSC(压缩稀疏列)格式的稀疏张量。
sparse_bsr_tensor在给定的crow_indicescol_indices处构造一个BSR(块压缩稀疏行)格式的稀疏张量,包含指定的二维块。
sparse_bsc_tensor在给定的ccol_indicesrow_indices处构造一个BSC(块压缩稀疏列)格式的稀疏张量,包含指定的二维块。
asarrayobj转换为张量。
as_tensordata转换为张量,尽可能共享数据并保留自动求导历史。
as_strided创建一个现有torch.Tensor的视图,指定sizestridestorage_offset
from_file创建一个由内存映射文件支持的CPU张量。
from_numpynumpy.ndarray创建Tensor
from_dlpack将外部库中的张量转换为torch.Tensor
frombuffer从实现Python缓冲区协议的对象创建一维Tensor
zeros返回一个填充标量值0的张量,形状由可变参数size定义。
zeros_like返回一个填充标量值0的张量,大小与input相同。
ones返回一个填充标量值1的张量,形状由可变参数size定义。
ones_like返回一个填充标量值1的张量,大小与input相同。
arange返回一个大小为⌈end−startstep⌉的一维张量,值取自区间[start, end),步长为step
range返回一个大小为⌊end−startstep⌋+1的一维张量,值从startend,步长为step
linspace创建一个大小为steps的一维张量,值从startend均匀分布(包含端点)。
logspace创建一个大小为steps的一维张量,值在base^startbase^end之间均匀分布(包含端点),对数刻度。
eye返回一个二维张量,对角线为1,其余为0。
empty返回一个填充未初始化数据的张量。
empty_like返回一个未初始化的张量,大小与input相同。
empty_strided创建一个指定sizestride的张量,填充未定义数据。
full创建一个大小为size的张量,填充fill_value
full_like返回一个大小与input相同的张量,填充fill_value
quantize_per_tensor将浮点张量转换为具有给定比例和零点的量化张量。
quantize_per_channel将浮点张量转换为具有给定比例和零点的逐通道量化张量。
dequantize通过反量化量化张量返回一个fp32张量。
complex构造一个复数张量,实部为real,虚部为imag
polar构造一个复数张量,其元素为极坐标对应的笛卡尔坐标,绝对值为abs,角度为angle
heaviside计算input中每个元素的Heaviside阶跃函数。

索引、切片、连接与变异操作

adjoint返回张量的共轭视图,并转置最后两个维度。
argwhere返回包含输入张量input所有非零元素索引的张量。
cat在给定维度上连接张量序列tensors
concattorch.cat()的别名。
concatenatetorch.cat()的别名。
conj返回翻转共轭位后的输入张量input视图。
chunk尝试将张量分割为指定数量的块。
dsplitindices_or_sections深度方向分割三维及以上张量input
column_stack通过水平堆叠tensors中的张量创建新张量。
dstack沿第三轴深度堆叠张量序列。
gather沿指定维度dim聚集值。
hsplitindices_or_sections水平分割一维及以上张量input
hstack水平(按列)堆叠张量序列。
index_add功能描述见index_add_()
index_copy功能描述见index_add_()
index_reduce功能描述见index_reduce_()
index_select使用长整型张量index沿维度dim索引输入张量input,返回新张量。
masked_select根据布尔掩码mask(BoolTensor类型)索引输入张量input,返回新的一维张量。
movedim将输入张量input的维度从source位置移动到destination位置。
moveaxistorch.movedim()的别名。
narrow返回输入张量input的缩小版本。
narrow_copy功能同Tensor.narrow(),但返回副本而非共享存储。
nonzero
permute返回原始张量input的维度重排视图。
reshape返回与input数据相同但形状改变的新张量。
row_stacktorch.vstack()的别名。
select在选定维度的给定索引处切片输入张量input
scattertorch.Tensor.scatter_()的非原位版本。
diagonal_scatter将源张量src的值沿dim1dim2嵌入到输入张量input的对角元素中。
select_scatter将源张量src的值嵌入到输入张量input的指定索引处。
slice_scatter将源张量src的值沿指定维度嵌入到输入张量input中。
scatter_addtorch.Tensor.scatter_add_()的非原位版本。
scatter_reducetorch.Tensor.scatter_reduce_()的非原位版本。
split将张量分割成块。
squeeze移除输入张量input中所有大小为1的指定维度。
stack沿新维度连接张量序列。
swapaxestorch.transpose()的别名。
swapdimstorch.transpose()的别名。
t要求输入input为≤2维张量,并转置维度0和1。
take返回输入张量input在给定索引处元素组成的新张量。
take_along_dim沿指定维度dim从一维索引indices处选择输入张量input的值。
tensor_split根据indices_or_sections将张量沿维度dim分割为多个子张量(均为input的视图)。
tile通过重复输入张量input的元素构造新张量。
transpose返回输入张量input的转置版本。
unbind移除张量维度。
unravel_index将平面索引张量转换为坐标张量元组,用于索引指定形状的任意张量。
unsqueeze在指定位置插入大小为1的维度,返回新张量。
vsplitindices_or_sections垂直分割二维及以上张量input
vstack垂直(按行)堆叠张量序列。
where根据条件conditioninputother中选择元素组成新张量。

加速器

在 PyTorch 代码库中,我们将"加速器"定义为与 CPU 协同工作以加速计算的 torch.device。这些设备采用异步执行方案,使用 torch.Streamtorch.Event 作为主要的同步机制。我们假设在给定主机上一次只能使用一个这样的加速器,这使得我们可以将当前加速器作为默认设备,用于固定内存、流设备类型、FSDP 等相关概念。

目前支持的加速器设备包括(无特定顺序):“CUDA”、“MTIA”、“XPU”、“MPS”、“HPU”以及 PrivateUse1(许多设备不在 PyTorch 代码库本身中)。

PyTorch 生态系统中的许多工具使用 fork 创建子进程(例如数据加载或操作内并行),因此应尽可能延迟任何会阻止后续 fork 的操作。这一点尤为重要,因为大多数加速器的初始化都会产生这种影响。实际应用中需注意,默认情况下检查 torch.accelerator.current_accelerator() 是编译时检查,因此始终是 fork 安全的。相反,向此函数传递 check_available=True 标志或调用 torch.accelerator.is_available() 通常会阻止后续 fork。

某些后端提供实验性的可选选项,使运行时可用性检查成为 fork 安全的。例如,使用 CUDA 设备时可以使用 PYTORCH_NVML_BASED_CUDA_CHECK=1

Stream按先进先出(FIFO)顺序异步执行相应任务的有序队列。
Event查询和记录流状态,以识别或控制跨流的依赖关系并测量时间。

生成器

Generator创建并返回一个生成器对象,该对象用于管理产生伪随机数的算法状态。

随机采样

seed将所有设备的随机数生成种子设置为非确定性随机数
manual_seed设置所有设备的随机数生成种子
initial_seed返回生成随机数的初始种子(Python长整型)
get_rng_state返回随机数生成器状态(torch.ByteTensor类型)
set_rng_state设置随机数生成器状态

torch.default_generator 返回默认的CPU torch.Generator


bernoulli从伯努利分布中抽取二元随机数(0或1)。
multinomial返回一个张量,其中每一行包含从对应行的张量input中的多项式(更严格的定义是多元的,更多细节请参考torch.distributions.multinomial.Multinomial)概率分布中抽取的num_samples个索引。
normal返回一个张量,其中的随机数是从具有给定均值和标准差的独立正态分布中抽取的。
poisson返回一个与input大小相同的张量,其中每个元素都是从具有由input中对应元素给出的速率参数的泊松分布中抽取的。
rand返回一个张量,其中填充了来自区间[0,1)上的均匀分布的随机数。
rand_like返回一个与input大小相同的张量,其中填充了来自区间[0,1)上的均匀分布的随机数。
randint返回一个张量,其中填充了在low(包含)和high(不包含)之间均匀生成的随机整数。
randint_like返回一个与张量input形状相同的张量,其中填充了在low(包含)和high(不包含)之间均匀生成的随机整数。
randn返回一个张量,其中填充了来自均值为0、方差为1的正态分布的随机数(也称为标准正态分布)。
randn_like返回一个与input大小相同的张量,其中填充了来自均值为0、方差为1的正态分布的随机数。
randperm返回一个从0n - 1的整数的随机排列。

原地随机采样

张量上还定义了一些原地随机采样函数。点击以下链接查看它们的文档:

  • torch.Tensor.bernoulli_() - torch.bernoulli() 的原地版本
  • torch.Tensor.cauchy_() - 从柯西分布中抽取的数字
  • torch.Tensor.exponential_() - 从指数分布中抽取的数字
  • torch.Tensor.geometric_() - 从几何分布中抽取的元素
  • torch.Tensor.log_normal_() - 从对数正态分布中抽取的样本
  • torch.Tensor.normal_() - torch.normal() 的原地版本
  • torch.Tensor.random_() - 从离散均匀分布中抽取的数字
  • torch.Tensor.uniform_() - 从连续均匀分布中抽取的数字

准随机采样

quasirandom.SobolEnginetorch.quasirandom.SobolEngine 是一个用于生成(加扰)Sobol序列的引擎。

序列化

save将对象保存到磁盘文件。
load从文件中加载由 torch.save() 保存的对象。

并行计算

get_num_threads返回用于CPU操作并行化的线程数
set_num_threads设置CPU上用于内部操作并行化的线程数
get_num_interop_threads返回CPU上用于操作间并行化的线程数(例如
set_num_interop_threads设置用于操作间并行化的线程数(例如

局部禁用梯度计算

上下文管理器 torch.no_grad()torch.enable_grad()torch.set_grad_enabled() 可用于在局部范围内禁用或启用梯度计算。具体用法详见局部禁用梯度计算文档。这些上下文管理器是线程局部的,因此如果通过threading模块等将工作发送到其他线程,它们将不会生效。


示例:

>>> x = torch.zeros(1, requires_grad=True)>>> with torch.no_grad():
...     y = x * 2
>>> y.requires_grad
False>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
...     y = x * 2
>>> y.requires_grad
False>>> torch.set_grad_enabled(True)  # 也可以作为函数使用
>>> y = x * 2
>>> y.requires_grad
True>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False

| no_grad | 禁用梯度计算的上下文管理器。 |

| — | — |

| enable_grad | 启用梯度计算的上下文管理器。 |

| autograd.grad_mode.set_grad_enabled | 设置梯度计算开启或关闭的上下文管理器。 |

| is_grad_enabled | 如果当前梯度模式已启用则返回True。 |

| autograd.grad_mode.inference_mode | 启用或禁用推理模式的上下文管理器。 |

| is_inference_mode_enabled | 如果当前推理模式已启用则返回True。 |


数学运算


常量

inf浮点正无穷大。math.inf 的别名。
nan浮点“非数字”值。该值不是一个合法的数字。math.nan 的别名。

逐点运算

abs计算 input 中每个元素的绝对值。
absolutetorch.abs() 的别名
acos计算 input 中每个元素的反余弦值。
arccostorch.acos() 的别名。
acosh返回一个新张量,包含 input 中元素的反双曲余弦值。
arccoshtorch.acosh() 的别名。
addother 乘以 alpha 后加到 input 上。
addcdivtensor1tensor2 进行逐元素除法,将结果乘以标量 value 后加到 input 上。
addcmultensor1tensor2 进行逐元素乘法,将结果乘以标量 value 后加到 input 上。
angle计算给定 input 张量中每个元素的角度(弧度)。
asin返回一个新张量,包含 input 中元素的反正弦值。
arcsintorch.asin() 的别名。
asinh返回一个新张量,包含 input 中元素的反双曲正弦值。
arcsinhtorch.asinh() 的别名。
atan返回一个新张量,包含 input 中元素的反正切值。
arctantorch.atan() 的别名。
atanh返回一个新张量,包含 input 中元素的反双曲正切值。
arctanhtorch.atanh() 的别名。
atan2逐元素计算 inputi/otheri\text{input}{i} / \text{other}{i}inputi​/otheri​ 的反正切值,并考虑象限。
arctan2torch.atan2() 的别名。
bitwise_not计算给定输入张量的按位取反。
bitwise_and计算 inputother 的按位与。
bitwise_or计算 inputother 的按位或。
bitwise_xor计算 inputother 的按位异或。
bitwise_left_shift计算 input 左移 other 位的算术结果。
bitwise_right_shift计算 input 右移 other 位的算术结果。
ceil返回一个新张量,包含 input 中每个元素的向上取整值(不小于该元素的最小整数)。
clampinput 中的所有元素限制在 minmax 范围内。
cliptorch.clamp() 的别名。
conj_physical计算给定 input 张量的逐元素共轭。
copysign创建一个新浮点张量,其数值为 input 的绝对值,符号为 other 的符号(逐元素)。
cos返回一个新张量,包含 input 中元素的余弦值。
cosh返回一个新张量,包含 input 中元素的双曲余弦值。
deg2rad返回一个新张量,将 input 中的角度从度转换为弧度。
divinput 中的每个元素除以 other 中对应的元素。
dividetorch.div() 的别名。
digammatorch.special.digamma() 的别名。
erftorch.special.erf() 的别名。
erfctorch.special.erfc() 的别名。
erfinvtorch.special.erfinv() 的别名。
exp返回一个新张量,包含输入张量 input 中元素的指数值。
exp2torch.special.exp2() 的别名。
expm1torch.special.expm1() 的别名。
fake_quantize_per_channel_affine返回一个新张量,其中 input 的数据按通道使用 scalezero_pointquant_minquant_max 进行伪量化,通道由 axis 指定。
fake_quantize_per_tensor_affine返回一个新张量,其中 input 的数据使用 scalezero_pointquant_minquant_max 进行伪量化。
fixtorch.trunc() 的别名
float_power以双精度逐元素计算 inputexponent 次幂。
floor返回一个新张量,包含 input 中每个元素的向下取整值(不大于该元素的最大整数)。
floor_divide
fmod逐元素应用 C++ 的 std::fmod。
frac计算 input 中每个元素的小数部分。
frexpinput 分解为尾数和指数张量,满足 input=mantissa×2exponent\text{input} = \text{mantissa} \times 2^{\text{exponent}}input=mantissa×2exponent。
gradient使用二阶精确中心差分法估计函数 g:Rn→Rg : \mathbb{R}^n \rightarrow \mathbb{R}g:Rn→R 在一维或多维上的梯度,边界处使用一阶或二阶估计。
imag返回一个新张量,包含 self 张量的虚部值。
ldexpinput 乘以 2 的 other 次方。
lerp根据标量或张量 weight 对两个张量 start(由 input 给出)和 end 进行线性插值,返回结果张量 out
lgamma计算 input 上伽马函数绝对值的自然对数。
log返回一个新张量,包含 input 中元素的自然对数。
log10返回一个新张量,包含 input 中元素的以 10 为底的对数。
log1p返回一个新张量,包含 (1 + input) 的自然对数。
[log2](https://docs.pytorch.org/docs/stable/generated/torch.log2.html#torch.log2 "

归约操作

argmax返回 input 张量中所有元素最大值的索引
argmin返回展平张量或沿指定维度中最小值的索引
amax返回 input 张量在给定维度 dim 上各切片的最大值
amin返回 input 张量在给定维度 dim 上各切片的最小值
aminmax计算 input 张量的最小值和最大值
all测试 input 中所有元素是否均为 True
any测试 input 中是否存在任意元素为 True
max返回 input 张量中所有元素的最大值
min返回 input 张量中所有元素的最小值
dist返回 (input - other) 的 p-范数
logsumexp返回 input 张量在给定维度 dim 上各行指数求和对数
mean
nanmean计算指定维度中所有非 NaN 元素的均值
median返回 input 中所有值的中位数
nanmedian返回 input 中所有值的中位数(忽略 NaN 值)
mode返回命名元组 (values, indices),其中 valuesinput 张量在给定维度 dim 上各行的众数值(即该行最常出现的值),indices 是各众数值的索引位置
norm返回给定张量的矩阵范数或向量范数
nansum返回所有元素的和(将 NaN 视为零)
prod返回 input 张量中所有元素的乘积
quantile计算 input 张量在维度 dim 上各行的 q 分位数
nanquantiletorch.quantile() 的变体,忽略 NaN 值计算分位数 q(如同 input 中的 NaN 不存在)
std计算指定维度 dim 上的标准差
std_mean计算指定维度 dim 上的标准差和均值
sum返回 input 张量中所有元素的和
unique返回输入张量的唯一元素
unique_consecutive去除连续等效元素组中除首个元素外的所有元素
var计算指定维度 dim 上的方差
var_mean计算指定维度 dim 上的方差和均值
count_nonzero统计张量 input 在给定维度 dim 上的非零值数量

比较运算

allclose检查 inputother 是否满足条件:
argsort返回按值升序排列张量沿指定维度的索引。
eq逐元素计算相等性
equal如果两个张量大小和元素相同返回 True,否则返回 False
ge逐元素计算 input≥other\text{input} \geq \text{other}input≥other。
greater_equaltorch.ge() 的别名。
gt逐元素计算 input>other\text{input} \text{other}input>other。
greatertorch.gt() 的别名。
isclose返回一个新张量,其布尔元素表示 input 的每个元素是否与 other 的对应元素"接近"。
isfinite返回一个新张量,其布尔元素表示每个元素是否为有限值。
isin测试 elements 的每个元素是否在 test_elements 中。
isinf测试 input 的每个元素是否为无穷大(正无穷或负无穷)。
isposinf测试 input 的每个元素是否为正无穷。
isneginf测试 input 的每个元素是否为负无穷。
isnan返回一个新张量,其布尔元素表示 input 的每个元素是否为 NaN。
isreal返回一个新张量,其布尔元素表示 input 的每个元素是否为实数值。
kthvalue返回一个命名元组 (values, indices),其中 valuesinput 张量在给定维度 dim 上每行的第 k 个最小元素。
le逐元素计算 input≤other\text{input} \leq \text{other}input≤other。
less_equaltorch.le() 的别名。
lt逐元素计算 input<other\text{input} < \text{other}input<other。
lesstorch.lt() 的别名。
maximum计算 inputother 的逐元素最大值。
minimum计算 inputother 的逐元素最小值。
fmax计算 inputother 的逐元素最大值。
fmin计算 inputother 的逐元素最小值。
ne逐元素计算 input≠other\text{input} \neq \text{other}input=other。
not_equaltorch.ne() 的别名。
sort按值升序排列 input 张量沿指定维度的元素。
topk返回 input 张量沿给定维度的前 k 个最大元素。
msort按值升序排列 input 张量沿其第一维度的元素。

频谱操作

stft短时傅里叶变换 (STFT)。
istft短时傅里叶逆变换。
bartlett_window巴特利特窗函数。
blackman_window布莱克曼窗函数。
hamming_window汉明窗函数。
hann_window汉恩窗函数。
kaiser_window计算具有窗口长度 window_length 和形状参数 beta 的凯撒窗。

其他操作

atleast_1d返回每个零维输入张量的一维视图。
atleast_2d返回每个零维输入张量的二维视图。
atleast_3d返回每个零维输入张量的三维视图。
bincount统计非负整数数组中每个值的出现频率。
block_diag根据提供的张量创建块对角矩阵。
broadcast_tensors按照广播语义广播给定张量。
broadcast_toinput广播至指定形状shape
broadcast_shapes功能类似broadcast_tensors(),但针对形状操作。
bucketize返回input中每个值所属的桶索引,桶边界由boundaries定义。
cartesian_prod计算给定张量序列的笛卡尔积。
cdist计算两组行向量之间批次化的p范数距离。
clone返回input的副本。
combinations计算给定张量中长度为rrr的组合。
corrcoef估计input矩阵的皮尔逊积矩相关系数矩阵,其中行代表变量,列代表观测值。
cov估计input矩阵的协方差矩阵,其中行代表变量,列代表观测值。
cross返回inputother在维度dim上的向量叉积。
cummax返回命名元组(values, indices),其中valuesinput在维度dim上的累积最大值。
cummin返回命名元组(values, indices),其中valuesinput在维度dim上的累积最小值。
cumprod返回input在维度dim上的累积乘积。
cumsum返回input在维度dim上的累积和。
diag* 若input为向量(1维张量),则返回2维方阵
diag_embed创建张量,其特定2D平面(由dim1dim2指定)的对角线由input填充。
diagflat* 若input为向量(1维张量),则返回2维方阵
diagonal返回input的部分视图,其中相对于dim1dim2的对角线元素被附加到形状末尾。
diff沿给定维度计算第n阶前向差分。
einsum根据爱因斯坦求和约定,沿指定维度对输入operands的元素乘积求和。
flatteninput展平为一维张量。
flip沿指定维度反转n维张量的顺序。
fliplr左右翻转张量,返回新张量。
flipud上下翻转张量,返回新张量。
kron计算inputother的克罗内克积(⊗)。
rot90在指定平面内将n维张量旋转90度。
gcd计算inputother的逐元素最大公约数(GCD)。
histc计算张量的直方图。
histogram计算张量值的直方图。
histogramdd计算张量值的多维直方图。
meshgrid根据1D输入张量创建坐标网格。
lcm计算inputother的逐元素最小公倍数(LCM)。
logcumsumexp返回input在维度dim上元素指数累积求和的对数值。
ravel返回连续的展平张量。
renorm返回归一化后的张量,其中input沿维度dim的每个子张量的p范数小于maxnorm
repeat_interleave重复张量元素。
roll沿指定维度滚动张量input
searchsortedsorted_sequence最内层维度查找索引,使得插入values对应值后仍保持排序顺序。
tensordot返回a和b在多个维度上的缩并结果。
trace返回输入2维矩阵对角线元素之和。
tril返回矩阵(2维张量)或矩阵批次input的下三角部分,结果张量out的其他元素设为0。
tril_indices返回row×col矩阵下三角部分的索引(2×N张量),首行为所有索引的行坐标,次行为列坐标。
triu返回矩阵(2维张量)或矩阵批次input的上三角部分,结果张量out的其他元素设为0。
triu_indices返回row×col矩阵上三角部分的索引(2×N张量),首行为所有索引的行坐标,次行为列坐标。
unflatten将输入张量的一个维度扩展为多个维度。
vander生成范德蒙矩阵。
view_as_real返回input作为实数张量的视图。
view_as_complex返回input作为复数张量的视图。
resolve_conjinput的共轭位为True,则返回具体化共轭的新张量,否则返回原张量。
resolve_neginput的负位为True,则返回具体化取反的新张量,否则返回原张量。

BLAS 和 LAPACK 运算

addbmm对存储在 batch1batch2 中的矩阵执行批量矩阵乘法,并带有缩减加法步骤(所有矩阵乘法沿第一维度累积)。
addmm对矩阵 mat1mat2 执行矩阵乘法。
addmv对矩阵 mat 和向量 vec 执行矩阵-向量乘法。
addr对向量 vec1vec2 执行外积,并将其加到矩阵 input 上。
baddbmmbatch1batch2 中的矩阵执行批量矩阵乘法。
bmm对存储在 inputmat2 中的矩阵执行批量矩阵乘法。
chain_matmul返回 NNN 个二维张量的矩阵乘积。
cholesky计算对称正定矩阵 AAA 或其批次的 Cholesky 分解。
cholesky_inverse给定其 Cholesky 分解,计算复 Hermitian 或实对称正定矩阵的逆。
cholesky_solve给定其 Cholesky 分解,计算具有复 Hermitian 或实对称正定 lhs 的线性方程组的解。
dot计算两个一维张量的点积。
geqrf这是一个直接调用 LAPACK 的 geqrf 的低级函数。
gertorch.outer() 的别名。
inner计算一维张量的点积。
inversetorch.linalg.inv() 的别名。
dettorch.linalg.det() 的别名。
logdet计算方阵或其批次的 log 行列式。
slogdettorch.linalg.slogdet() 的别名。
lu计算矩阵或其批次 A 的 LU 分解。
lu_solve使用来自 lu_factor() 的部分主元 LU 分解,返回线性方程组 Ax=bAx = bAx=b 的 LU 解。
lu_unpacklu_factor() 返回的 LU 分解解包为 P、L、U 矩阵。
matmul两个张量的矩阵乘积。
matrix_powertorch.linalg.matrix_power() 的别名。
matrix_exptorch.linalg.matrix_exp() 的别名。
mm对矩阵 inputmat2 执行矩阵乘法。
mv对矩阵 input 和向量 vec 执行矩阵-向量乘法。
orgqrtorch.linalg.householder_product() 的别名。
ormqr计算 Householder 矩阵与一般矩阵的矩阵-矩阵乘法。
outerinputvec2 的外积。
pinversetorch.linalg.pinv() 的别名。
qr计算矩阵或其批次 input 的 QR 分解,并返回一个命名元组 (Q, R),使得 input=QR\text{input} = Q Rinput=QR,其中 QQQ 是正交矩阵或其批次,RRR 是上三角矩阵或其批次。
svd计算矩阵或其批次 input 的奇异值分解。
svd_lowrank返回矩阵、矩阵批次或稀疏矩阵 AAA 的奇异值分解 (U, S, V),使得 A≈Udiag⁡(S)VHA \approx U \operatorname{diag}(S) V^{\text{H}}A≈Udiag(S)VH。
pca_lowrank对低秩矩阵、其批次或稀疏矩阵执行线性主成分分析 (PCA)。
lobpcg使用无矩阵 LOBPCG 方法找到对称正定广义特征值问题的 k 个最大(或最小)特征值及其对应的特征向量。
trapztorch.trapezoid() 的别名。
trapezoid沿 dim 计算梯形法则。
cumulative_trapezoid沿 dim 累积计算梯形法则。
triangular_solve解具有方上或下三角可逆矩阵 AAA 和多个右侧 bbb 的方程组。
vdot沿某一维度计算两个一维向量的点积。

遍历操作


警告:此API处于测试阶段,未来可能会有变更。
不支持前向模式自动微分。

_foreach_abs对输入列表中的每个张量应用 torch.abs()
_foreach_abs_对输入列表中的每个张量应用 torch.abs()
_foreach_acos对输入列表中的每个张量应用 torch.acos()
_foreach_acos_对输入列表中的每个张量应用 torch.acos()
_foreach_asin对输入列表中的每个张量应用 [torch.asin()](https://docs.pytorch.org/docs/极简翻译结果:

遍历操作


警告:此API处于测试阶段,未来可能会有变更。
不支持前向模式自动微分。

_foreach_abs对输入列表中的每个张量应用 torch.abs()
_foreach_abs_对输入列表中的每个张量应用 torch.abs()
_foreach_acos对输入列表中的每个张量应用 torch.acos()
_foreach_acos_对输入列表中的每个张量应用 torch.acos()
_foreach_asin对输入列表中的每个张量应用 torch.asin()
_foreach_asin_对输入列表中的每个张量应用 torch.asin()
_foreach_atan对输入列表中的每个张量应用 [torch.atan()](https://docs.pytorch.org/docs/stable/generated/torch.atan.html#极简翻译结果:

遍历操作


警告:此API处于测试阶段,未来可能会有变更。
不支持前向模式自动微分。

| [_foreach_abs](https://docs.pytorch.org/docs/stable/generated/torch._foreach_abs.html#极简翻译结果:

遍历操作


警告:此API处于测试阶段,未来可能会有变更。
不支持前向模式自动微分。

_foreach_abs对输入列表中的每个张量应用 torch.abs()
[_foreach_abs_](https://docs.pytorch.org/docs/stable/generated/torch.foreach_abs.html#torch._foreach极简翻译结果:

遍历操作


警告:此API处于测试阶段,未来可能会有变更。
不支持前向模式自动微分。

_foreach_abs对输入列表中的每个张量应用 torch.abs()
_foreach_abs_对输入列表中的每个张量应用 torch.abs()
_foreach_acos对输入列表中的每个张量应用 torch.acos()
_foreach_acos_对输入列表中的每个张量应用 torch.acos()
[_foreach_asin](https://docs.pytorch.org/docs/stable/generated极简翻译结果:

遍历操作


警告:此API处于测试阶段,未来可能会有变更。
不支持前向模式自动微分。

| [_foreach_abs](https://docs.pytorch.org/docs/stable/generated/torch._foreach_abs.html#torch._foreach_abs "torch._foreach极简翻译结果:

遍历操作


警告:此API处于测试阶段,未来可能会有变更。
不支持前向模式自动微分。

_foreach_abs对输入列表中的每个张量应用 torch.abs()
_foreach_abs_对输入列表中的每个张量应用 torch.abs()
_foreach_acos对输入列表中的每个张量应用 [`tor极简翻译结果:

遍历操作


警告:此API处于测试阶段,未来可能会有变更。
不支持前向模式自动微分。

_foreach_abs对输入列表中的每个张量应用 torch.abs()
_foreach_abs_对输入列表中的每个张量应用 torch.abs()
_foreach_acos对输入列表中的每个张量应用 torch.acos()
_foreach_acos_对输入列表中的每个张量应用 torch.acos()
_foreach_asin对输入列表中的每个张量应用 torch.asin()
_foreach_asin_对输入列表中的每个张量应用 torch.asin()
_foreach_atan对输入列表中的每个张量应用 torch.atan()
[`极简翻译结果:

遍历操作


警告:此API处于测试阶段,未来可能会有变更。
不支持前向模式自动微分。

| _foreach_abs | 对输入列表中的每个张量应用 [torch.abs()](https://docs.pytorch.org/docs/stable极简翻译结果:

遍历操作


警告:此API处于测试阶段,未来可能会有变更。
不支持前向模式自动微分。

_foreach_abs对输入列表中的每个张量应用 torch.abs()
_foreach_abs_对输入列表中的每个张量应用 torch.abs()
_foreach_acos对输入列表中的每个张量应用 torch.acos()
_foreach_acos_对输入列表中的每个张量应用 torch.acos()
_foreach_asin对输入列表中的每个张量应用 torch.asin()
_foreach_asin_对输入列表中的每个张量应用 torch.asin()
_foreach_atan对输入列表中的每个张量应用 torch.atan()
_foreach_atan_对输入列表中的每个张量应用 torch.atan()
_foreach_ceil对输入列表中的每个张量应用 [torch.ceil()](https://docs.pytorch.org/docs/stable/generated/torch.ceil.html#torch.ceil "极简翻译结果:

遍历操作


警告:此API处于测试阶段,未来可能会有变更。
不支持前向模式自动微分。

_foreach_abs对输入列表中的每个张量应用 torch.abs()
_foreach_abs_对输入列表中的每个张量应用 torch.abs()
[_foreach_acos](https://docs.pytorch.org/docs/stable/generated/torch._foreach极简翻译结果:

遍历操作


警告:此API处于测试阶段,未来可能会有变更。
不支持前向模式自动微分。

_foreach_abs对输入列表中的每个张量应用 torch.abs()
_foreach_abs_对输入列表中的每个张量应用 [torch.abs()](https://docs.pytorch.org/docs/stable/g

实用工具

compiled_with_cxx11_abi返回PyTorch是否使用_GLIBCXX_USE_CXX11_ABI=1编译
result_type返回对输入张量执行算术运算后得到的torch.dtype类型
can_cast根据类型提升文档描述的PyTorch类型转换规则,判断是否允许类型转换
promote_types返回不小于type1或type2的最小尺寸和标量类型的torch.dtype
use_deterministic_algorithms设置PyTorch操作是否必须使用"确定性"算法
are_deterministic_algorithms_enabled如果全局确定性标志已开启则返回True
is_deterministic_algorithms_warn_only_enabled如果全局确定性标志设置为仅警告则返回True
set_deterministic_debug_mode设置确定性操作的调试模式
get_deterministic_debug_mode返回当前确定性操作的调试模式值
set_float32_matmul_precision设置float32矩阵乘法的内部精度
get_float32_matmul_precision返回当前float32矩阵乘法精度值
set_warn_always当此标志为False(默认)时,某些PyTorch警告可能每个进程只出现一次
get_device_module返回与给定设备关联的模块(如torch.device(‘cuda’), “mtia:0”, "xpu"等)
is_warn_always_enabled如果全局warn_always标志已开启则返回True
vmapvmap是向量化映射;vmap(func)返回一个新函数,该函数在输入的某些维度上映射func
_assertPython assert的可符号追踪包装器

符号数字


class torch.SymInt(node)[source]

类似于整型(包括魔术方法),但会重定向所有对封装节点的操作。这尤其用于在符号化形状工作流中记录符号化操作。


as_integer_ratio() 

将该整数表示为精确的整数比例

返回类型 tuple[SymInt', int]


class torch.SymFloat(node)

像一个浮点数(包括魔术方法),但会重定向所有对包装节点的操作。这尤其用于在符号化形状工作流中象征性地记录操作。


as_integer_ratio()

将这个浮点数表示为精确的整数比例

返回类型:tuple[int, int]


conjugate()

返回该浮点数的复共轭值。

返回类型:SymFloat


hex()

返回浮点数的十六进制表示形式。

返回类型 str


is_integer()

如果浮点数是整数,则返回 True。


class torch.SymBool(node)

类似于布尔类型(包括魔术方法),但会重定向所有对包装节点的操作。这尤其用于在符号化形状工作流中符号化记录操作。

与常规布尔类型不同,常规布尔运算符会强制生成额外的保护条件,而不是进行符号化求值。应改用位运算符来处理这种情况。

sym_float支持SymInt的浮点数转换工具
sym_fresh_size
sym_int支持SymInt的整数转换工具
sym_max支持SymInt的最大值工具,避免在a < b时进行分支判断
sym_min支持SymInt的最小值工具
sym_not支持SymInt的逻辑取反工具
sym_ite
sym_sum多元加法工具,对于长列表的计算速度比迭代二元加法更快

导出路径

警告:此功能为原型阶段,未来可能包含不兼容的变更。

export generated/exportdb/index


控制流

警告:此功能为原型阶段,未来可能存在破坏性变更。

cond根据条件选择执行 true_fn 或 false_fn

优化方法

compile使用TorchDynamo和指定后端优化给定模型/函数

torch.compile文档


操作符标签


class torch.Tag 

成员:

core

data_dependent_output

dynamic_output_shape

flexible_layout

generated

inplace_view

maybe_aliasing_or_mutating

needs_fixed_stride_order

nondeterministic_bitwise

nondeterministic_seeded

pointwise

pt2_compliant_tag

view_copy


torch.nn

以下是构建图模型的基本组件:

torch.nn

Buffer一种不应被视为模型参数的张量类型
Parameter一种需要被视为模块参数的张量类型
UninitializedParameter未初始化的参数
UninitializedBuffer未初始化的缓冲区

容器模块

Module所有神经网络模块的基类
Sequential顺序容器
ModuleList以列表形式存储子模块
ModuleDict以字典形式存储子模块
ParameterList以列表形式存储参数
ParameterDict以字典形式存储参数

模块全局钩子

register_module_forward_pre_hook为所有模块注册前向预处理钩子
register_module_forward_hook为所有模块注册全局前向钩子
register_module_backward_hook为所有模块注册反向传播钩子
register_module_full_backward_pre_hook为所有模块注册反向预处理钩子
register_module_full_backward_hook为所有模块注册完整反向传播钩子
register_module_buffer_registration_hook为所有模块注册缓冲区注册钩子
register_module_module_registration_hook为所有模块注册子模块注册钩子
register_module_parameter_registration_hook为所有模块注册参数注册钩子

卷积层

nn.Conv1d对由多个输入平面组成的输入信号进行一维卷积运算
nn.Conv2d对由多个输入平面组成的输入信号进行二维卷积运算
nn.Conv3d对由多个输入平面组成的输入信号进行三维卷积运算
nn.ConvTranspose1d对由多个输入平面组成的输入图像进行一维转置卷积运算
nn.ConvTranspose2d对由多个输入平面组成的输入图像进行二维转置卷积运算
nn.ConvTranspose3d对由多个输入平面组成的输入图像进行三维转置卷积运算
nn.LazyConv1d具有in_channels参数延迟初始化特性的torch.nn.Conv1d模块
nn.LazyConv2d具有in_channels参数延迟初始化特性的torch.nn.Conv2d模块
nn.LazyConv3d具有in_channels参数延迟初始化特性的torch.nn.Conv3d模块
nn.LazyConvTranspose1d具有in_channels参数延迟初始化特性的torch.nn.ConvTranspose1d模块
nn.LazyConvTranspose2d具有in_channels参数延迟初始化特性的torch.nn.ConvTranspose2d模块
nn.LazyConvTranspose3d具有in_channels参数延迟初始化特性的torch.nn.ConvTranspose3d模块
nn.Unfold从批处理输入张量中提取滑动局部块
nn.Fold将滑动局部块数组合并为一个包含张量

池化层

nn.MaxPool1d对由多个输入平面组成的输入信号应用一维最大池化。
nn.MaxPool2d对由多个输入平面组成的输入信号应用二维最大池化。
nn.MaxPool3d对由多个输入平面组成的输入信号应用三维最大池化。
nn.MaxUnpool1d计算 MaxPool1d 的部分逆运算。
nn.MaxUnpool2d计算 MaxPool2d 的部分逆运算。
nn.MaxUnpool3d计算 MaxPool3d 的部分逆运算。
nn.AvgPool1d对由多个输入平面组成的输入信号应用一维平均池化。
nn.AvgPool2d对由多个输入平面组成的输入信号应用二维平均池化。
nn.AvgPool3d对由多个输入平面组成的输入信号应用三维平均池化。
nn.FractionalMaxPool2d对由多个输入平面组成的输入信号应用二维分数最大池化。
nn.FractionalMaxPool3d对由多个输入平面组成的输入信号应用三维分数最大池化。
nn.LPPool1d对由多个输入平面组成的输入信号应用一维幂平均池化。
nn.LPPool2d对由多个输入平面组成的输入信号应用二维幂平均池化。
nn.LPPool3d对由多个输入平面组成的输入信号应用三维幂平均池化。
nn.AdaptiveMaxPool1d对由多个输入平面组成的输入信号应用一维自适应最大池化。
nn.AdaptiveMaxPool2d对由多个输入平面组成的输入信号应用二维自适应最大池化。
nn.AdaptiveMaxPool3d对由多个输入平面组成的输入信号应用三维自适应最大池化。
nn.AdaptiveAvgPool1d对由多个输入平面组成的输入信号应用一维自适应平均池化。
nn.AdaptiveAvgPool2d对由多个输入平面组成的输入信号应用二维自适应平均池化。
nn.AdaptiveAvgPool3d对由多个输入平面组成的输入信号应用三维自适应平均池化。

填充层

nn.ReflectionPad1d使用输入边界的反射来填充输入张量
nn.ReflectionPad2d使用输入边界的反射来填充输入张量
nn.ReflectionPad3d使用输入边界的反射来填充输入张量
nn.ReplicationPad1d使用输入边界的复制来填充输入张量
nn.ReplicationPad2d使用输入边界的复制来填充输入张量
nn.ReplicationPad3d使用输入边界的复制来填充输入张量
nn.ZeroPad1d用零值填充输入张量边界
nn.ZeroPad2d用零值填充输入张量边界
nn.ZeroPad3d用零值填充输入张量边界
nn.ConstantPad1d用常数值填充输入张量边界
nn.ConstantPad2d用常数值填充输入张量边界
nn.ConstantPad3d用常数值填充输入张量边界
nn.CircularPad1d使用输入边界的循环填充来填充输入张量
nn.CircularPad2d使用输入边界的循环填充来填充输入张量
nn.CircularPad3d使用输入边界的循环填充来填充输入张量

非线性激活函数(加权求和与非线性变换)

nn.ELU逐元素应用指数线性单元(ELU)函数
nn.Hardshrink逐元素应用硬收缩(Hardshrink)函数
nn.Hardsigmoid逐元素应用硬Sigmoid函数
nn.Hardtanh逐元素应用HardTanh函数
nn.Hardswish逐元素应用Hardswish函数
nn.LeakyReLU逐元素应用LeakyReLU函数
nn.LogSigmoid逐元素应用Logsigmoid函数
nn.MultiheadAttention使模型能够共同关注来自不同表示子空间的信息
nn.PReLU逐元素应用PReLU函数
nn.ReLU逐元素应用修正线性单元函数
nn.ReLU6逐元素应用ReLU6函数
nn.RReLU逐元素应用随机泄漏修正线性单元函数
nn.SELU逐元素应用SELU函数
nn.CELU逐元素应用CELU函数
nn.GELU应用高斯误差线性单元函数
nn.Sigmoid逐元素应用Sigmoid函数
nn.SiLU逐元素应用Sigmoid线性单元(SiLU)函数
nn.Mish逐元素应用Mish函数
nn.Softplus逐元素应用Softplus函数
nn.Softshrink逐元素应用软收缩函数
nn.Softsign逐元素应用Softsign函数
nn.Tanh逐元素应用双曲正切(Tanh)函数
nn.Tanhshrink逐元素应用Tanhshrink函数
nn.Threshold对输入张量的每个元素进行阈值处理
nn.GLU应用门控线性单元函数

非线性激活函数(其他)

nn.Softmin对n维输入张量应用Softmin函数
nn.Softmax对n维输入张量应用Softmax函数
nn.Softmax2d在每个空间位置上对特征应用SoftMax
nn.LogSoftmax对n维输入张量应用log⁡(Softmax(x))\log(\text{Softmax}(x))log(Softmax(x))函数
nn.AdaptiveLogSoftmaxWithLoss高效的softmax近似方法

归一化层

nn.BatchNorm1d对2D或3D输入应用批量归一化
nn.BatchNorm2d对4D输入应用批量归一化
nn.BatchNorm3d对5D输入应用批量归一化
nn.LazyBatchNorm1d具有延迟初始化功能的torch.nn.BatchNorm1d模块
nn.LazyBatchNorm2d具有延迟初始化功能的torch.nn.BatchNorm2d模块
nn.LazyBatchNorm3d具有延迟初始化功能的torch.nn.BatchNorm3d模块
nn.GroupNorm对小批量输入应用组归一化
nn.SyncBatchNorm对N维输入应用批量归一化
nn.InstanceNorm1d应用实例归一化
nn.InstanceNorm2d应用实例归一化
nn.InstanceNorm3d应用实例归一化
nn.LazyInstanceNorm1d具有num_features参数延迟初始化功能的torch.nn.InstanceNorm1d模块
nn.LazyInstanceNorm2d具有num_features参数延迟初始化功能的torch.nn.InstanceNorm2d模块
nn.LazyInstanceNorm3d具有num_features参数延迟初始化功能的torch.nn.InstanceNorm3d模块
nn.LayerNorm对小批量输入应用层归一化
nn.LocalResponseNorm对输入信号应用局部响应归一化
nn.RMSNorm对小批量输入应用均方根层归一化

循环神经网络层

nn.RNNBaseRNN模块的基类(包括RNN、LSTM、GRU)
nn.RNN对输入序列应用多层Elman RNN,使用tanh或ReLU非线性激活函数
nn.LSTM对输入序列应用多层长短期记忆(LSTM)循环神经网络
nn.GRU对输入序列应用多层门控循环单元(GRU)网络
nn.RNNCell具有tanh或ReLU非线性激活的Elman RNN单元
nn.LSTMCell长短期记忆(LSTM)单元
nn.GRUCell门控循环单元(GRU)

Transformer 层

nn.TransformerTransformer 模型
nn.TransformerEncoderTransformerEncoder 由 N 个编码器层堆叠而成
nn.TransformerDecoderTransformerDecoder 由 N 个解码器层堆叠而成
nn.TransformerEncoderLayerTransformerEncoderLayer 由自注意力机制和前馈网络组成
nn.TransformerDecoderLayerTransformerDecoderLayer 由自注意力机制、多头注意力机制和前馈网络组成

线性层

nn.Identity一个参数无关的占位恒等运算符
nn.Linear对输入数据进行仿射线性变换:y=xAT+by = xA^T + by=xAT+b
nn.Bilinear对输入数据进行双线性变换:y=x1TAx2+by = x_1^T A x_2 + by=x1T​Ax2​+b
nn.LazyLinear一个自动推断输入特征数(in_features)的torch.nn.Linear模块

Dropout 层

nn.Dropout在训练过程中,以概率 p 随机将输入张量的部分元素置零。
nn.Dropout1d随机将整个通道置零。
nn.Dropout2d随机将整个通道置零。
nn.Dropout3d随机将整个通道置零。
nn.AlphaDropout对输入应用 Alpha Dropout。
nn.FeatureAlphaDropout随机屏蔽整个通道。

稀疏层

nn.Embedding一个简单的查找表,用于存储固定字典和大小的嵌入向量。
nn.EmbeddingBag计算嵌入向量"包"的和或均值,而无需实例化中间嵌入向量。

距离函数

nn.CosineSimilarity返回x1和x2沿指定维度的余弦相似度
nn.PairwiseDistance计算输入向量之间的成对距离,或输入矩阵列之间的成对距离

损失函数

nn.L1Loss创建一个衡量输入x和目标y之间平均绝对误差(MAE)的损失函数。
nn.MSELoss创建一个衡量输入x和目标y之间均方误差(平方L2范数)的损失函数。
nn.CrossEntropyLoss计算输入logits和目标之间的交叉熵损失。
nn.CTCLoss连接时序分类损失(Connectionist Temporal Classification loss)。
nn.NLLLoss负对数似然损失。
nn.PoissonNLLLoss目标服从泊松分布的负对数似然损失。
nn.GaussianNLLLoss高斯负对数似然损失。
nn.KLDivLossKL散度损失(Kullback-Leibler divergence loss)。
nn.BCELoss创建一个衡量目标与输入概率之间二元交叉熵的损失函数。
nn.BCEWithLogitsLoss将Sigmoid层和BCELoss结合在一个类中的损失函数。
nn.MarginRankingLoss创建一个衡量给定输入x1、x2(两个1D mini-batch或0D张量)和标签y(包含1或-1的1D mini-batch或0D张量)的损失函数。
nn.HingeEmbeddingLoss衡量给定输入张量x和标签张量y(包含1或-1)的损失。
nn.MultiLabelMarginLoss创建一个优化输入x(2D mini-batch张量)和输出y(目标类别索引的2D张量)之间多类多分类铰链损失(基于边距的损失)的损失函数。
nn.HuberLoss创建一个在元素级绝对误差低于delta时使用平方项,否则使用delta缩放L1项的损失函数。
nn.SmoothL1Loss创建一个在元素级绝对误差低于beta时使用平方项,否则使用L1项的损失函数。
nn.SoftMarginLoss创建一个优化输入张量x和目标张量y(包含1或-1)之间二分类逻辑损失的损失函数。
nn.MultiLabelSoftMarginLoss创建一个基于最大熵优化输入x和大小(N,C)的目标y之间多标签一对多损失的损失函数。
nn.CosineEmbeddingLoss创建一个衡量给定输入张量x1、x2和值为1或-1的标签张量y的损失函数。
nn.MultiMarginLoss创建一个优化输入x(2D mini-batch张量)和输出y(目标类别索引的1D张量,0≤y≤x.size(1)−1)之间多类分类铰链损失(基于边距的损失)的损失函数。
nn.TripletMarginLoss创建一个衡量给定输入张量x1、x2、x3和大于0的边距值的三元组损失的损失函数。
nn.TripletMarginWithDistanceLoss创建一个衡量给定输入张量a、p、n(分别表示锚点、正例和负例)以及用于计算锚点与正例(“正距离”)和锚点与负例(“负距离”)之间关系的非负实值函数(“距离函数”)的三元组损失的损失函数。

视觉层

nn.PixelShuffle根据上采样因子重新排列张量中的元素
nn.PixelUnshuffle反转PixelShuffle操作
nn.Upsample对给定的多通道1D(时序)、2D(空间)或3D(体积)数据进行上采样
nn.UpsamplingNearest2d对由多个输入通道组成的输入信号应用2D最近邻上采样
nn.UpsamplingBilinear2d对由多个输入通道组成的输入信号应用2D双线性上采样

通道混洗层

nn.ChannelShuffle对张量中的通道进行分组并重新排列

数据并行层(多GPU,分布式)

nn.DataParallel在模块级别实现数据并行。
nn.parallel.DistributedDataParallel基于torch.distributed在模块级别实现分布式数据并行。

实用工具

来自 torch.nn.utils 模块的实用函数:

参数梯度裁剪工具

clip_grad_norm_对一组可迭代参数的梯度范数进行裁剪
clip_grad_norm对一组可迭代参数的梯度范数进行裁剪
clip_grad_value_按指定值裁剪一组可迭代参数的梯度
get_total_norm计算一组张量的范数
clip_grads_with_norm_根据预计算的总范数和期望的最大范数缩放一组参数的梯度

模块参数扁平化与反扁平化工具

parameters_to_vector将一组参数展平为单个向量
vector_to_parameters将向量切片复制到一组参数中

模块与批归一化融合工具

fuse_conv_bn_eval将卷积模块和批归一化模块融合为新的卷积模块
fuse_conv_bn_weights将卷积模块参数和批归一化模块参数融合为新的卷积模块参数
fuse_linear_bn_eval将线性模块和批归一化模块融合为新的线性模块
fuse_linear_bn_weights将线性模块参数和批归一化模块参数融合为新的线性模块参数

模块参数内存格式转换工具

convert_conv2d_weight_memory_format转换 nn.Conv2d.weightmemory_format
convert_conv3d_weight_memory_format转换 nn.Conv3d.weightmemory_format,该转换会递归应用到嵌套的 nn.Module 包括 module

权重归一化应用与移除工具

weight_norm对给定模块中的参数应用权重归一化
remove_weight_norm从模块中移除权重归一化重参数化
spectral_norm对给定模块中的参数应用谱归一化
remove_spectral_norm从模块中移除谱归一化重参数化

模块参数初始化工具

| skip_init | 给定模块类对象和参数,实例化模块但不初始化参数/缓冲区 |


模块参数剪枝工具类与函数

prune.BasePruningMethod创建新剪枝技术的抽象基类
prune.PruningContainer包含一系列剪枝方法的容器,用于迭代剪枝
prune.Identity不剪枝任何单元但生成带有全1掩码的剪枝参数化的实用方法
prune.RandomUnstructured随机剪枝张量中当前未剪枝的单元
prune.L1Unstructured通过置零L1范数最小的单元来剪枝张量
prune.RandomStructured随机剪枝张量中当前未剪枝的整个通道
prune.LnStructured基于Ln范数剪枝张量中当前未剪枝的整个通道
prune.CustomFromMask
prune.identity应用剪枝重参数化但不剪枝任何单元
prune.random_unstructured通过移除随机未剪枝单元来剪枝张量
prune.l1_unstructured通过移除L1范数最小的单元来剪枝张量
prune.random_structured通过沿指定维度移除随机通道来剪枝张量
prune.ln_structured通过沿指定维度移除Ln范数最小的通道来剪枝张量
prune.global_unstructured通过应用指定的pruning_method全局剪枝parameters中所有参数对应的张量
prune.custom_from_mask通过应用预计算掩码mask来剪枝module中名为name的参数对应的张量
prune.remove从模块中移除剪枝重参数化并从前向钩子中移除剪枝方法
prune.is_pruned通过查找剪枝前钩子检查模块是否被剪枝

使用torch.nn.utils.parameterize.register_parametrization()新参数化功能实现的参数化

parametrizations.orthogonal对矩阵或矩阵批次应用正交或酉参数化
parametrizations.weight_norm对给定模块中的参数应用权重归一化
parametrizations.spectral_norm对给定模块中的参数应用谱归一化

为现有模块上的张量参数化的实用函数

注意:这些函数可用于通过特定函数将给定参数或缓冲区从输入空间映射到参数化空间。它们不是将对象转换为参数的参数化。有关如何实现自定义参数化的更多信息,请参阅参数化教程。

parametrize.register_parametrization为模块中的张量注册参数化
parametrize.remove_parametrizations移除模块中张量的参数化
parametrize.cached启用通过register_parametrization()注册的参数化内部缓存系统的上下文管理器
parametrize.is_parametrized判断模块是否具有参数化
parametrize.ParametrizationList顺序容器,用于保存和管理参数化torch.nn.Module的原始参数或缓冲区

以无状态方式调用给定模块的实用函数

| stateless.functional_call | 通过用提供的参数和缓冲区替换模块参数和缓冲区来执行功能调用 |


其他模块中的实用函数

nn.utils.rnn.PackedSequence保存打包序列的数据和batch_sizes列表
nn.utils.rnn.pack_padded_sequence打包包含可变长度填充序列的张量
nn.utils.rnn.pad_packed_sequence对打包的可变长度序列批次进行填充
nn.utils.rnn.pad_sequencepadding_value填充可变长度张量列表
nn.utils.rnn.pack_sequence打包可变长度张量列表
nn.utils.rnn.unpack_sequence将PackedSequence解包为可变长度张量列表
nn.utils.rnn.unpad_sequence将填充张量解包为可变长度张量列表
nn.Flatten将连续范围的维度展平为张量
nn.Unflatten将张量维度展开为期望形状

量化函数

量化是指以低于浮点精度的位宽执行计算和存储张量的技术。PyTorch 同时支持逐张量和逐通道的非对称线性量化。要了解更多关于如何在 PyTorch 中使用量化函数的信息,请参阅量化文档。


惰性模块初始化

nn.modules.lazy.LazyModuleMixin用于实现参数惰性初始化的模块混合类(也称为"惰性模块")

别名

以下是 torch.nn 中对应模块的别名:

nn.modules.normalization.RMSNorm对输入的小批量数据应用均方根层归一化。

torch.nn.functional


卷积函数

conv1d对由多个输入平面组成的输入信号应用一维卷积运算
conv2d对由多个输入平面组成的输入图像应用二维卷积运算
conv3d对由多个输入平面组成的输入图像应用三维卷积运算
conv_transpose1d对由多个输入平面组成的输入信号应用一维转置卷积算子(有时也称为"反卷积")
conv_transpose2d对由多个输入平面组成的输入图像应用二维转置卷积算子(有时也称为"反卷积")
conv_transpose3d对由多个输入平面组成的输入图像应用三维转置卷积算子(有时也称为"反卷积")
unfold从批处理输入张量中提取滑动局部块
fold将滑动局部块数组合并为一个包含张量

池化函数

avg_pool1d对由多个输入平面组成的输入信号应用一维平均池化。
avg_pool2d在kH×kW区域以步长sH×sW进行二维平均池化操作。
avg_pool3d在kT×kH×kW区域以步长sT×sH×sW进行三维平均池化操作。
max_pool1d对由多个输入平面组成的输入信号应用一维最大池化。
max_pool2d对由多个输入平面组成的输入信号应用二维最大池化。
max_pool3d对由多个输入平面组成的输入信号应用三维最大池化。
max_unpool1d计算MaxPool1d的部分逆运算。
max_unpool2d计算MaxPool2d的部分逆运算。
max_unpool3d计算MaxPool3d的部分逆运算。
lp_pool1d对由多个输入平面组成的输入信号应用一维幂平均池化。
lp_pool2d对由多个输入平面组成的输入信号应用二维幂平均池化。
lp_pool3d对由多个输入平面组成的输入信号应用三维幂平均池化。
adaptive_max_pool1d对由多个输入平面组成的输入信号应用一维自适应最大池化。
adaptive_max_pool2d对由多个输入平面组成的输入信号应用二维自适应最大池化。
adaptive_max_pool3d对由多个输入平面组成的输入信号应用三维自适应最大池化。
adaptive_avg_pool1d对由多个输入平面组成的输入信号应用一维自适应平均池化。
adaptive_avg_pool2d对由多个输入平面组成的输入信号应用二维自适应平均池化。
adaptive_avg_pool3d对由多个输入平面组成的输入信号应用三维自适应平均池化。
fractional_max_pool2d对由多个输入平面组成的输入信号应用二维分数最大池化。
fractional_max_pool3d对由多个输入平面组成的输入信号应用三维分数最大池化。

注意力机制

torch.nn.attention.bias 模块包含专为 scaled_dot_product_attention 设计的注意力偏置项。

scaled_dot_product_attentionscaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,

非线性激活函数

threshold对输入张量的每个元素应用阈值处理
threshold_threshold() 的原位操作版本
relu逐元素应用修正线性单元函数
relu_relu() 的原位操作版本
hardtanh逐元素应用 HardTanh 函数
hardtanh_hardtanh() 的原位操作版本
hardswish逐元素应用 hardswish 函数
relu6逐元素应用 ReLU6(x)=min⁡(max⁡(0,x),6)\text{ReLU6}(x) = \min(\max(0,x), 6)ReLU6(x)=min(max(0,x),6) 函数
elu逐元素应用指数线性单元(ELU)函数
elu_elu() 的原位操作版本
selu逐元素应用 SELU(x)=scale∗(max⁡(0,x)+min⁡(0,α∗(exp⁡(x)−1)))\text{SELU}(x) = scale * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))SELU(x)=scale∗(max(0,x)+min(0,α∗(exp(x)−1))) 函数,其中 α=1.6732632423543772848170429916717\alpha=1.6732632423543772848170429916717α=1.6732632423543772848170429916717,scale=1.0507009873554804934193349852946scale=1.0507009873554804934193349852946scale=1.0507009873554804934193349852946
celu逐元素应用 CELU(x)=max⁡(0,x)+min⁡(0,α∗(exp⁡(x/α)−1))\text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))CELU(x)=max(0,x)+min(0,α∗(exp(x/α)−1)) 函数
leaky_relu逐元素应用 LeakyReLU(x)=max⁡(0,x)+negative_slope∗min⁡(0,x)\text{LeakyReLU}(x) = \max(0, x) + \text{negative_slope} * \min(0, x)LeakyReLU(x)=max(0,x)+negative_slope∗min(0,x) 函数
leaky_relu_leaky_relu() 的原位操作版本
prelu逐元素应用 PReLU(x)=max⁡(0,x)+weight∗min⁡(0,x)\text{PReLU}(x) = \max(0,x) + \text{weight} * \min(0,x)PReLU(x)=max(0,x)+weight∗min(0,x) 函数,其中 weight 是可学习参数
rrelu随机泄漏 ReLU
rrelu_rrelu() 的原位操作版本
glu门控线性单元
gelu当 approximate 参数为 ‘none’ 时,逐元素应用 GELU(x)=x∗Φ(x)\text{GELU}(x) = x * \Phi(x)GELU(x)=x∗Φ(x) 函数
logsigmoid逐元素应用 LogSigmoid(xi)=log⁡(11+exp⁡(−xi))\text{LogSigmoid}(x_i) = \log \left(\frac{1}{1 + \exp(-x_i)}\right)LogSigmoid(xi​)=log(1+exp(−xi​)1​) 函数
hardshrink逐元素应用硬收缩函数
tanhshrink逐元素应用 Tanhshrink(x)=x−Tanh(x)\text{Tanhshrink}(x) = x - \text{Tanh}(x)Tanhshrink(x)=x−Tanh(x) 函数
softsign逐元素应用 SoftSign(x)=x1+∣x∣\text{SoftSign}(x) = \frac{x}{1 +
softplus逐元素应用 Softplus(x)=1β∗log⁡(1+exp⁡(β∗x))\text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))Softplus(x)=β1​∗log(1+exp(β∗x)) 函数
softmin应用 softmin 函数
softmax应用 softmax 函数
softshrink逐元素应用软收缩函数
gumbel_softmax从 Gumbel-Softmax 分布(链接1 链接2)采样并可选离散化
log_softmax应用 softmax 后接对数运算
tanh逐元素应用 Tanh(x)=tanh⁡(x)=exp⁡(x)−exp⁡(−x)exp⁡(x)+exp⁡(−x)\text{Tanh}(极x) = \tanh(x) = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)}Tanh(x)=tanh(x)=exp(x)+exp(−x)exp(x)−exp(−x)​ 函数
sigmoid逐元素应用 Sigmoid(x)=11+exp⁡(−x)\text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)}Sigmoid(x)=1+exp(−x)1​ 函数
hardsigmoid逐元素应用 Hardsigmoid 函数
silu逐元素应用 Sigmoid 线性单元(SiLU)函数
mish逐元素应用 Mish 函数
batch_norm对批量数据中的每个通道应用批量归一化
group_norm对最后若干维度应用组归一化
instance_norm对批量中每个数据样本的每个通道独立应用实例归一化
layer_norm对最后若干维度应用层归一化
local_response_norm对输入信号应用局部响应归一化
rms_norm应用均方根层归一化
normalize对指定维度执行 LpL_pLp​ 归一化

线性函数

linear对输入数据应用线性变换:y=xAT+by = xA^T + by=xAT+b
bilinear对输入数据应用双线性变换:y=x1TAx2+by = x_1^T A x_2 + by=x1T​Ax2​+b

Dropout 函数

dropout在训练过程中,以概率 p 随机将输入张量的部分元素置零。
alpha_dropout对输入应用 alpha dropout。
feature_alpha_dropout随机屏蔽整个通道(通道即特征图)。
dropout1d随机将整个通道置零(通道为 1D 特征图)。
dropout2d随机将整个通道置零(通道为 2D 特征图)。
dropout3d随机将整个通道置零(通道为 3D 特征图)。

稀疏函数

embedding生成一个简单的查找表,用于在固定字典和尺寸中查找嵌入向量。
embedding_bag计算嵌入向量包的和、平均值或最大值。
one_hot接收形状为()的LongTensor索引值,返回形状为(, num_classes)的张量,该张量除最后一维索引与输入张量对应值匹配的位置为1外,其余位置均为0。

距离函数

pairwise_distance详情参见 torch.nn.PairwiseDistance
cosine_similarity返回 x1x2 沿指定维度的余弦相似度
pdist计算输入中每对行向量之间的 p-范数距离

损失函数

binary_cross_entropy计算目标值与输入概率之间的二元交叉熵
binary_cross_entropy_with_logits计算目标值与输入logits之间的二元交叉熵
poisson_nll_loss泊松负对数似然损失
cosine_embedding_loss详见CosineEmbeddingLoss
cross_entropy计算输入logits与目标值之间的交叉熵损失
ctc_loss应用连接时序分类损失
gaussian_nll_loss高斯负对数似然损失
hinge_embedding_loss详见HingeEmbeddingLoss
[kl_div](https://docs.pytorch.org/docs/stable/generated/t torch.nn.functional.kl_div.html#torch.nn.functional.kl_div “torch.nn.functional.kl_div”)计算KL散度损失
l1_loss计算元素级绝对差值的均值
mse_loss计算元素级均方误差(支持加权)
margin_ranking_loss详见MarginRankingLoss
multilabel_margin_loss详见MultiLabelMarginLoss
multilabel_soft_margin_loss详见MultiLabelSoftMarginLoss
multi_margin_loss详见MultiMarginLoss
nll_loss计算负对数似然损失
huber_loss计算Huber损失(支持加权)
smooth_l1_loss计算平滑L1损失
soft_margin_loss详见SoftMarginLoss
triplet_margin_loss计算输入张量与大于0的边界值之间的三元组损失
triplet_margin_with_distance_loss使用自定义距离函数计算输入张量的三元组边界损失

视觉函数

pixel_shuffle将形状为 ( ∗ , C × r 2 , H , W ) (∗,C×r^2,H,W) (,C×r2,H,W)的张量元素重新排列为形状 ( ∗ , C , H × r , W × r ) (∗,C, H × r, W × r) (,C,H×r,W×r)的张量,其中r为upscale_factor
pixel_unshuffle通过将形状为 ( ∗ , C , H × r , W × r ) (∗,C, H × r, W × r) (,C,H×r,W×r)的张量元素重新排列为形状 ( ∗ , C × r 2 , H , W ) (∗,C×r^2,H,W) (,C×r2,H,W)的张量,来逆转PixelShuffle操作,其中r为downscale_factor
pad对张量进行填充。
interpolate对输入进行下采样/上采样。
upsample对输入进行上采样。
upsample_nearest使用最近邻像素值对输入进行上采样。
upsample_bilinear使用双线性上采样对输入进行上采样。
grid_sample计算网格采样。
affine_grid给定一批仿射矩阵theta,生成2D或3D流场(采样网格)。

DataParallel 功能(多GPU,分布式)


data_parallel

torch.nn.parallel.data_parallel在指定设备ID列表(device_ids)中的多个GPU上并行评估模块(input)。

torch.Tensor

torch.Tensor 是一个包含单一数据类型元素的多维矩阵。


数据类型

Torch 定义了以下数据类型的张量:

数据类型dtype
32位浮点数torch.float32torch.float
64位浮点数torch.float64torch.double
16位浮点数 [1torch.float16torch.half
16位浮点数 [2torch.bfloat16
32位复数torch.complex32torch.chalf
64位复数torch.complex64torch.cfloat
128位复数torch.complex128torch.cdouble
8位整数(无符号)torch.uint8
16位整数(无符号)torch.uint16(有限支持)[4
32位整数(无符号)torch.uint32(有限支持)[4
64位整数(无符号)torch.uint64(有限支持)[4
8位整数(有符号)torch.int8
16位整数(有符号)torch.int16torch.short
32位整数(有符号)torch.int32torch.int
64位整数(有符号)torch.int64torch.long
布尔值torch.bool
量化8位整数(无符号)torch.quint8
量化8位整数(有符号)torch.qint8
量化32位整数(有符号)torch.qint32
量化4位整数(无符号)[3torch.quint4x2
8位浮点数,e4m3 [5torch.float8_e4m3fn(有限支持)
8位浮点数,e5m2 [5torch.float8_e5m2(有限支持)

[1
有时称为 binary16:使用1位符号、5位指数和10位尾数。在精度比范围更重要时很有用。

[2
有时称为 Brain Floating Point:使用1位符号、8位指数和7位尾数。在范围更重要时很有用,因为它与 float32 具有相同数量的指数位。

[3
量化4位整数存储为8位有符号整数。目前仅在 EmbeddingBag 操作符中支持。

4([1 ,[2 ,[3 )
uint8 外的无符号类型目前计划仅在 eager 模式下提供有限支持(它们主要用于辅助 torch.compile 的使用);如果需要 eager 支持且不需要额外的范围,建议使用其有符号变体。详情请参阅 https://github.com/pytorch/pytorch/issues/58734。

5([1 ,[2 )
torch.float8_e4m3fntorch.float8_e5m2 实现了来自 https://arxiv.org/abs/2209.05433 的8位浮点数规范。操作支持非常有限。

为了向后兼容,我们支持以下这些数据类型的替代类名:

数据类型CPU 张量GPU 张量
32位浮点数torch.FloatTensortorch.cuda.FloatTensor
64位浮点数torch.DoubleTensortorch.cuda.DoubleTensor
16位浮点数torch.HalfTensortorch.cuda.HalfTensor
16位浮点数torch.BFloat16Tensortorch.cuda.BFloat16Tensor
8位整数(无符号)torch.ByteTensortorch.cuda.ByteTensor
8位整数(有符号)torch.CharTensortorch.cuda.CharTensor
16位整数(有符号)torch.ShortTensortorch.cuda.ShortTensor
32位整数(有符号)torch.IntTensortorch.cuda.IntTensor
64位整数(有符号)torch.LongTensortorch.cuda.LongTensor
布尔值torch.BoolTensortorch.cuda.BoolTensor

然而,为了构造张量,我们建议使用工厂函数如 torch.empty() 并指定 dtype 参数。torch.Tensor 构造函数是默认张量类型(torch.FloatTensor)的别名。


初始化与基础操作

可以通过 Python 的 list 或序列使用 torch.tensor() 构造函数来构建张量:

>>> torch.tensor([[1., -1.], [1., -1.]])
tensor([[1.0000, -1.0000], [1.0000, -1.0000]])
>>> torch.tensor(np.array([[1, 2, 3], [4, 5, 6]]))
tensor([[1, 2, 3], [4, 5, 6]])

警告:torch.tensor() 总是会复制 data。如果你已经有一个 Tensor data 并且只想修改它的 requires_grad 标志,请使用 requires_grad_()detach() 来避免复制操作。

如果你有一个 numpy 数组并且希望避免复制,请使用 torch.as_tensor()

可以通过向构造函数或张量创建操作传递 torch.dtype 和/或 torch.device 来构造特定数据类型的张量:

>>> torch.zeros([2, 4], dtype=torch.int32)
tensor([[0, 0, 0, 0], [0, 0, 0, 0]], dtype=torch.int32)
>>> cuda0 = torch.device('cuda:0')
>>> torch.ones([2, 4], dtype=torch.float64, device=cuda0)
tensor([[1.0000, 1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000, 1.0000]], dtype=torch.float64, device='cuda:0')

有关构建张量的更多信息,请参阅创建操作。

可以使用Python的索引和切片符号来访问和修改张量的内容:

>>> x = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> print(x[1][2])
tensor(6)
>>> x[0][1] = 8
>>> print(x)
tensor([[1, 8, 3], [4, 5, 6]])

使用 torch.Tensor.item() 从包含单个值的张量中获取 Python 数值:

>>> x = torch.tensor([[1]])
>>> x
tensor([[1]])
>>> x.item()
1
>>> x = torch.tensor(2.5)
>>> x
tensor(2.5000)
>>> x.item()
2.5

有关索引的更多信息,请参阅索引、切片、连接和变异操作

可以通过设置requires_grad=True来创建张量,这样torch.autograd会记录对其的操作以实现自动微分。


>>> x = torch.tensor([[1., -1.], [1., 1.]], requires_grad=True)
>>> out = x.pow(2).sum()
>>> out.backward()
>>> x.grad
tensor([[2.0000, -2.0000], [2.0000, 2.0000]])

每个张量都有一个关联的 torch.Storage,用于存储其数据。

张量类还提供了存储的多维跨步视图,并定义了基于它的数值运算。

注意:有关张量视图的更多信息,请参阅张量视图。

注意:关于 torch.Tensortorch.dtypetorch.devicetorch.layout 属性的更多信息,请参阅张量属性。

注意:会改变张量的方法以下划线后缀标记。例如,torch.FloatTensor.abs_() 会就地计算绝对值并返回修改后的张量,而 torch.FloatTensor.abs() 则会在新张量中计算结果。

注意:要更改现有张量的 torch.device 和/或 torch.dtype,可以考虑使用张量的 to() 方法。

警告:当前 torch.Tensor 的实现引入了内存开销,因此在处理大量小张量的应用中可能导致意外的高内存使用。如果遇到这种情况,建议使用单个大型结构。


Tensor 类参考


class torch.Tensor 

根据不同的使用场景,创建张量主要有以下几种方式:

  • 若要从现有数据创建张量,请使用 torch.tensor()
  • 若要创建指定大小的张量,请使用 torch.* 张量创建操作(参见创建操作)。
  • 若要创建与另一个张量大小相同(且类型相似)的张量,请使用 torch.*_like 张量创建操作(参见创建操作)。
  • 若要创建类型相似但大小不同的张量,请使用 tensor.new_* 创建操作。
  • 存在一个遗留构造函数 torch.Tensor,不建议继续使用。请改用 torch.tensor()

Tensor.__init__(self, data)

该构造函数已弃用,建议改用 torch.tensor()

此构造函数的行为取决于 data 的类型:

  • 如果 data 是 Tensor,则返回原始 Tensor 的别名。与 torch.tensor() 不同,此操作会跟踪自动微分并将梯度传播到原始 Tensor。对于这种 data 类型不支持 device 关键字参数。
  • 如果 data 是序列或嵌套序列,则创建一个默认数据类型(通常是 torch.float32)的张量,其数据为序列中的值,必要时执行类型转换。值得注意的是,此构造函数与 torch.tensor() 的区别在于,即使输入全是整数,此构造函数也会始终构造浮点张量。
  • 如果 datatorch.Size,则返回一个该大小的空张量。

此构造函数不支持显式指定返回张量的 dtypedevice。建议使用 torch.tensor(),它提供了此功能。

参数:

data (array_like): 用于构造张量的数据。

关键字参数:

device (torch.device, 可选): 返回张量的目标设备。默认值:如果为 None,则与此张量相同的 torch.device。

Tensor.T

返回此张量的维度反转视图。

如果 x 的维度数为 n,则 x.T 等价于 x.permute(n-1, n-2, ..., 0)

警告: 在非二维张量上使用 Tensor.T() 来反转形状的做法已弃用,未来版本中将抛出错误。对于矩阵批量的转置,请考虑使用 mT;对于张量维度的反转,请使用 x.permute(torch.arange(x.ndim - 1, -1, -1))

Tensor.H

返回矩阵(二维张量)的共轭转置视图。

对于复数矩阵,x.H 等价于 x.transpose(0, 1).conj();对于实数矩阵,等价于 x.transpose(0, 1)

另请参阅

mH: 同样适用于矩阵批量的属性。

Tensor.mT

返回此张量最后两个维度转置的视图。

x.mT 等价于 x.transpose(-2, -1)

Tensor.mH

访问此属性等价于调用 adjoint()

方法表

Tensor.new_tensor返回以 data 为张量数据的新 Tensor。
Tensor.new_full返回大小为 size 且填充 fill_value 的 Tensor。
Tensor.new_empty返回大小为 size 且填充未初始化数据的 Tensor。
Tensor.new_ones返回大小为 size 且填充 1 的 Tensor。
Tensor.new_zeros返回大小为 size 且填充 0 的 Tensor。
Tensor.is_cuda如果 Tensor 存储在 GPU 上则为 True,否则为 False
Tensor.is_quantized如果 Tensor 是量化张量则为 True,否则为 False
Tensor.is_meta如果 Tensor 是元张量则为 True,否则为 False
Tensor.device返回此 Tensor 所在的 torch.device
Tensor.grad此属性默认为 None,在首次调用 backward() 计算 self 的梯度时会变为 Tensor。
Tensor.ndimdim() 的别名
Tensor.real对于复数输入张量,返回包含 self 张量实部值的新张量。
Tensor.imag返回包含 self 张量虚部值的新张量。
Tensor.nbytes如果张量不使用稀疏存储布局,则返回张量元素视图占用的字节数。
Tensor.itemsizeelement_size() 的别名
Tensor.abs参见 torch.abs()
Tensor.abs_abs() 的原位版本
Tensor.absoluteabs() 的别名
Tensor.absolute_absolute() 的原位版本,abs_() 的别名
Tensor.acos参见 torch.acos()
Tensor.acos_acos() 的原位版本
Tensor.arccos参见 torch.arccos()
Tensor.arccos_arccos() 的原位版本
Tensor.add将标量或张量加到 self 张量上。
Tensor.add_add() 的原位版本
Tensor.addbmm参见 torch.addbmm()
Tensor.addbmm_addbmm() 的原位版本
Tensor.addcdiv参见 torch.addcdiv()
Tensor.addcdiv_addcdiv() 的原位版本
Tensor.addcmul参见 torch.addcmul()
Tensor.addcmul_addcmul() 的原位版本
Tensor.addmm参见 torch.addmm()
Tensor.addmm_addmm() 的原位版本
Tensor.sspaddmm参见 torch.sspaddmm()
Tensor.addmv参见 torch.addmv()
Tensor.addmv_addmv() 的原位版本
Tensor.addr参见 torch.addr()
Tensor.addr_addr() 的原位版本
Tensor.adjointadjoint() 的别名
Tensor.allclose参见 torch.allclose()
Tensor.amax参见 torch.amax()
Tensor.amin参见 torch.amin()
Tensor.aminmax参见 torch.aminmax()
Tensor.angle参见 [torch.angle()](https://docs.pytorch.org/docs/stable/generated

张量属性

每个 torch.Tensor 都拥有 torch.dtypetorch.devicetorch.layout 属性。


torch.dtype


class torch.dtype 

torch.dtype 是一个表示 torch.Tensor 数据类型的对象。PyTorch 提供了十二种不同的数据类型:

数据类型dtype旧版构造函数
32位浮点数torch.float32torch.floattorch.*.FloatTensor
64位浮点数torch.float64torch.doubletorch.*.DoubleTensor
64位复数torch.complex64torch.cfloat
128位复数torch.complex128torch.cdouble
16位浮点数 [1]torch.float16torch.halftorch.*.HalfTensor
16位浮点数 [2]torch.bfloat16torch.*.BFloat16Tensor
8位无符号整数torch.uint8torch.*.ByteTensor
8位有符号整数torch.int8torch.*.CharTensor
16位有符号整数torch.int16torch.shorttorch.*.ShortTensor
32位有符号整数torch.int32torch.inttorch.*.IntTensor
64位有符号整数torch.int64torch.longtorch.*.LongTensor
布尔型torch.booltorch.*.BoolTensor

[1] 有时称为 binary16:使用 1 位符号、5 位指数和 10 位尾数。适用于需要高精度的场景。

[2] 有时称为 Brain 浮点数:使用 1 位符号、8 位指数和 7 位尾数。由于与 float32 具有相同的指数位数,适用于需要大范围的场景。

要判断 torch.dtype 是否为浮点数据类型,可以使用属性 is_floating_point,如果数据类型是浮点类型,则返回 True

要判断 torch.dtype 是否为复数数据类型,可以使用属性 is_complex,如果数据类型是复数类型,则返回 True

当算术运算(加、减、除、乘)的输入数据类型不同时,我们会按照以下规则找到满足条件的最小数据类型进行提升:

  • 如果标量操作数的类型属于比张量操作数更高的类别(复数 > 浮点 > 整数 > 布尔值),则提升到足以容纳该类别所有标量操作数的类型。
  • 如果零维张量操作数的类别高于有维度的操作数,则提升到足以容纳该类别所有零维张量操作数的类型。
  • 如果没有更高类别的零维操作数,则提升到足以容纳所有有维度操作数的类型。

浮点标量操作数的默认数据类型为 torch.get_default_dtype(),而整数非布尔标量操作数的默认数据类型为 torch.int64。与 NumPy 不同,我们在确定操作数的最小数据类型时不会检查具体值。目前不支持量化和复数类型的提升。

提升示例:

>>> float_tensor = torch.ones(1, dtype=torch.float)
>>> double_tensor = torch.ones(1, dtype=torch.double)
>>> complex_float_tensor = torch.ones(1, dtype=torch.complex64)
>>> complex_double_tensor = torch.ones(1, dtype=torch.complex128)
>>> int_tensor = torch.ones(1, dtype=torch.int)
>>> long_tensor = torch.ones(1, dtype=torch.long)
>>> uint_tensor = torch.ones(1, dtype=torch.uint8)
>>> bool_tensor = torch.ones(1, dtype=torch.bool)
# zero-dim tensors
>>> long_zerodim = torch.tensor(1, dtype=torch.long)
>>> int_zerodim = torch.tensor(1, dtype=torch.int)>>> torch.add(5, 5).dtype
torch.int64
# 5 is an int64, but does not have higher category than int_tensor so is not considered.
>>> (int_tensor + 5).dtype
torch.int32
>>> (int_tensor + long_zerodim).dtype
torch.int32
>>> (long_tensor + int_tensor).dtype
torch.int64
>>> (bool_tensor + long_tensor).dtype
torch.int64
>>> (bool_tensor + uint_tensor).dtype
torch.uint8
>>> (float_tensor + double_tensor).dtype
torch.float64
>>> (complex_float_tensor + complex_double_tensor).dtype
torch.complex128
>>> (bool_tensor + int_tensor).dtype
torch.int32
# Since long is a different kind than float, result dtype only needs to be large enough
# to hold the float.
>>> torch.add(long_tensor, float_tensor).dtype
torch.float32

当指定算术运算的输出张量时,我们允许将其类型转换为输出张量的数据类型,但存在以下例外情况:

  • 整型输出张量不能接受浮点型张量
  • 布尔型输出张量不能接受非布尔型张量
  • 非复数型输出张量不能接受复数型张量

类型转换示例:

# allowed:
>>> float_tensor *= float_tensor
>>> float_tensor *= int_tensor
>>> float_tensor *= uint_tensor
>>> float_tensor *= bool_tensor
>>> float_tensor *= double_tensor
>>> int_tensor *= long_tensor
>>> int_tensor *= uint_tensor
>>> uint_tensor *= int_tensor# disallowed (RuntimeError: result type can't be cast to the desired output type):
>>> int_tensor *= float_tensor
>>> bool_tensor *= int_tensor
>>> bool_tensor *= uint_tensor
>>> float_tensor *= complex_float_tensor

torch.device


class torch.device 

torch.device 是一个表示设备类型的对象,torch.Tensor 会被分配或已经分配在该设备上。

torch.device 包含一个设备类型(最常见的是 “cpu” 或 “cuda”,但也可能是 “mps”、“xpu”、“xla” 或 “meta”)以及可选的设备序号。如果未指定设备序号,该对象将始终代表该设备类型的当前设备,即使在调用 torch.cuda.set_device() 之后也是如此;例如,使用设备 'cuda' 构造的 torch.Tensor 等同于 'cuda:X',其中 X 是 torch.cuda.current_device() 的结果。

可以通过 Tensor.device 属性访问 torch.Tensor 的设备。

torch.device 可以通过字符串或字符串加设备序号来构造:

通过字符串:

>>> torch.device('cuda:0')
device(type='cuda', index=0)>>> torch.device('cpu')
device(type='cpu')>>> torch.device('mps')
device(type='mps')>>> torch.device('cuda')  # current cuda device
device(type='cuda')

通过字符串和设备序号:

>>> torch.device('cuda', 0)
device(type='cuda', index=0)>>> torch.device('mps', 0)
device(type='mps', index=0)>>> torch.device('cpu', 0)
device(type='cpu', index=0)

设备对象也可用作上下文管理器,用于更改张量分配的默认设备:

>>> with torch.device('cuda:1'):
...     r = torch.randn(2, 3)
>>> r.device
device(type='cuda', index=1)

如果向工厂函数传递了显式且非 None 的设备参数,此上下文管理器将不起作用。要全局更改默认设备,请参阅 torch.set_default_device()

警告:此函数会对每次调用 torch API 的 Python 操作(不仅限于工厂函数)产生轻微性能开销。如果这给您带来问题,请在 https://github.com/pytorch/pytorch/issues/92701 发表评论。

注意:函数中的 torch.device 参数通常可以用字符串替代,这有助于快速原型开发代码。


>>> # Example of a function that takes in a torch.device
>>> cuda1 = torch.device('cuda:1')
>>> torch.randn((2,3), device=cuda1)

>>> # You can substitute the torch.device with a string
>>> torch.randn((2,3), device='cuda:1')

注意:由于历史遗留原因,可以通过单个设备序号来构造设备,该序号会被视为当前加速器类型。

这与 Tensor.get_device() 的行为一致——该方法会返回设备张量的序号,但不支持CPU张量。


>>> torch.device(1)
device(type='cuda', index=1)

注意:接收设备参数的方法通常支持(格式正确的)字符串或(旧版)整数设备序号,以下写法都是等效的:

>>> torch.randn((2,3), device=torch.device('cuda:1'))
>>> torch.randn((2,3), device='cuda:1')
>>> torch.randn((2,3), device=1)  # legacy

注意:张量不会自动在设备间移动,需要用户显式调用。标量张量(tensor.dim()==0的情况)是此规则唯一的例外——当需要时它们会自动从CPU转移到GPU,因为该操作可以"零成本"完成。


示例:

>>> # two scalars
>>> torch.ones(()) + torch.ones(()).cuda()  # OK, scalar auto-transferred from CPU to GPU
>>> torch.ones(()).cuda() + torch.ones(())  # OK, scalar auto-transferred from CPU to GPU

>>> # one scalar (CPU), one vector (GPU)
>>> torch.ones(()) + torch.ones(1).cuda()  # OK, scalar auto-transferred from CPU to GPU
>>> torch.ones(1).cuda() + torch.ones(())  # OK, scalar auto-transferred from CPU to GPU

>>> # one scalar (GPU), one vector (CPU)
>>> torch.ones(()).cuda() + torch.ones(1)  # Fail, scalar not auto-transferred from GPU to CPU and non-scalar not auto-transferred from CPU to GPU
>>> torch.ones(1) + torch.ones(()).cuda()  # Fail, scalar not auto-transferred from GPU to CPU and non-scalar not auto-transferred from CPU to GPU

torch.layout


class torch.layout 

警告:torch.layout 类目前处于测试阶段,后续可能会发生变化。

torch.layout 是一个表示 torch.Tensor 内存布局的对象。目前我们支持 torch.strided(密集张量),并对 torch.sparse_coo(稀疏 COO 张量)提供测试版支持。

torch.strided 表示密集张量,这是最常用的内存布局方式。每个跨步张量都有一个关联的 torch.Storage 对象用于存储数据。这些张量提供了存储的多维跨步视图。跨步是一个整数列表:第 k 个跨步表示在张量的第 k 维中,从一个元素移动到下一个元素所需的内存跳跃量。这个概念使得许多张量操作能够高效执行。


示例:

>>> x = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
>>> x.stride()
(5, 1)>>> x.t().stride()
(1, 5)

有关 torch.sparse_coo 张量的更多信息,请参阅 torch.sparse。


torch.memory_format


class torch.memory_format 

torch.memory_format 是一个表示内存格式的对象,用于描述 torch.Tensor 当前或将要分配的内存布局。

可能的取值包括:

  • torch.contiguous_format

张量当前或将要分配在密集且无重叠的内存中。其步长(strides)以递减顺序表示。

  • torch.channels_last

张量当前或将要分配在密集且无重叠的内存中。其步长遵循 strides[0] strides[2] strides[3] strides[1] == 1 的顺序,即 NHWC 格式。

  • torch.channels_last_3d

张量当前或将要分配在密集且无重叠的内存中。其步长遵循 strides[0] strides[2] strides[3] strides[4] strides[1] == 1 的顺序,即 NDHWC 格式。

  • torch.preserve_format

用于 clone 等函数中,以保留输入张量的内存格式。如果输入张量分配在密集且无重叠的内存中,输出张量的步长将从输入张量复制。否则,输出张量的步长将遵循 torch.contiguous_format


张量视图

PyTorch 允许一个张量作为现有张量的视图(View)。视图张量与其基础张量共享相同底层数据。支持视图可以避免显式数据拷贝,从而实现快速且内存高效的形状变换、切片和逐元素操作。

例如,要获取现有张量t的视图,可以调用t.view(...)方法。


>>> t = torch.rand(4, 4)
>>> b = t.view(2, 8)
>>> t.storage().data_ptr() == b.storage().data_ptr()  # `t` and `b` share the same underlying data.
True
# Modifying view tensor changes base tensor as well.
>>> b[0][0] = 3.14
>>> t[0][0]
tensor(3.14)

由于视图与基础张量共享底层数据,当修改视图中的数据时,基础张量也会同步更新。

PyTorch操作通常返回一个新张量作为输出,例如add()。但对于视图操作,输出会作为输入张量的视图以避免不必要的数据拷贝。

创建视图时不会发生数据移动,视图张量仅改变了对同一数据的解释方式。对连续张量取视图可能会产生非连续张量。

用户需特别注意,因为连续性可能隐式影响性能。transpose()就是典型示例。


>>> base = torch.tensor([[0, 1],[2, 3]])
>>> base.is_contiguous()
True
>>> t = base.transpose(0, 1)  # `t` is a view of `base`. No data movement happened here.
# View tensors might be non-contiguous.
>>> t.is_contiguous()
False
# To get a contiguous tensor, call `.contiguous()` to enforce
# copying data when `t` is not contiguous.
>>> c = t.contiguous()

以下是PyTorch中视图操作(view ops)的完整参考列表:

  • 基础切片和索引操作,例如tensor[0, 2:, 1:7:2]会返回基础tensor的视图(注意事项见下文)
  • adjoint()
  • as_strided()
  • detach()
  • diagonal()
  • expand()
  • expand_as()
  • movedim()
  • narrow()
  • permute()
  • select()
  • squeeze()
  • transpose()
  • t()
  • T
  • H
  • mT
  • mH
  • real
  • imag
  • view_as_real()
  • unflatten()
  • unfold()
  • unsqueeze()
  • view()
  • view_as()
  • unbind()
  • split()
  • hsplit()
  • vsplit()
  • tensor_split()
  • split_with_sizes()
  • swapaxes()
  • swapdims()
  • chunk()
  • indices()(仅稀疏张量)
  • values()(仅稀疏张量)

注意事项:
当通过索引访问张量内容时,PyTorch遵循NumPy的行为规范:基础索引返回视图,而高级索引返回副本。无论是基础索引还是高级索引进行的赋值操作都是就地(in-place)执行的。更多示例可参考NumPy索引文档。

需要特别说明的几个操作:

  • reshape()reshape_as()flatten()可能返回视图或新张量,用户代码不应依赖其返回类型
  • contiguous()在输入张量已连续时返回其自身,否则会通过复制数据返回新的连续张量

如需深入了解PyTorch内部实现机制,请参阅ezyang的PyTorch内部原理博客文章。


2025-05-10(六)

http://www.xdnf.cn/news/381061.html

相关文章:

  • 【LangChain全景指南】构建下一代AI应用的开发框架
  • 数字相机的快门结构
  • not a genuine st device abort connection的问题
  • 实现三个采集板数据传送到一个显示屏的方案
  • null 的安全操作 vs 危险操作
  • Linux环境下基于Ncurses开发贪吃蛇小游戏
  • Java 内存模型 JMM
  • Edububtu 系统详解
  • Exploring Temporal Event Cues for Dense Video Captioning in Cyclic Co-Learning
  • 一个好用的快速学习的网站
  • python打卡day21
  • JavaScript基础-作用域概述
  • JDK10新特性
  • Apache Shiro 1.2.4 反序列化漏洞(CVE-2016-4437)
  • 二进制与十六进制数据转换:原理、实现与应用
  • DAY 21 常见的降维算法
  • 简述Web和HTTP
  • centos7.9上安装 freecad 指定安装位置
  • WinCC V7.2到V8.0与S71200/1500系列连接通讯教程以及避坑点
  • 码蹄集——向下取整(求立方根)、整理玩具、三角形斜边、完全平方数、个人所得税
  • MQTT协议介绍
  • 数据结构算法习题通关:树遍历 / 哈夫曼 / 拓扑 / 哈希 / Dijkstra 全解析
  • Python中的列表list使用详解
  • 重复的子字符串
  • 【ts】defineProps数组的类型声明
  • 人工智能100问☞第19问:什么是专家系统?
  • 自定义类型-结构体(二)
  • 基于ssm的超市库存商品管理系统(全套)
  • Vue.js框架的优缺点
  • 2025年PMP 学习六 -第5章 项目范围管理 (5.1,5.2,5.3)