pytorch小记(二十二):全面解读 PyTorch 的 `torch.cumprod`——累积乘积详解与实战示例
pytorch小记(二十二):全面解读 PyTorch 的 `torch.cumprod`——累积乘积详解与实战示例
- 一、函数签名与参数说明
- 二、基础用法
- 1. 一维张量累积乘积
- 2. 二维张量按行/按列累积
- 三、`dtype` 参数:避免整数溢出与提升精度
- 四、典型应用场景
- 1. 几何序列生成
- 2. 概率分布的累积乘积
- 3. 模型门控或权重衰减
- 五、进阶示例:预分配 `out` 张量
- 六、小结
在深度学习与科学计算中,往往需要沿某个维度追踪“前面所有元素的乘积”,比如几何序列计算、概率分布构建、模型门控/权重衰减等场景。PyTorch 提供的 torch.cumprod
函数可以一行代码搞定这一需求。本文将从函数签名、参数含义、基础用法,到进阶示例、典型应用场景,为你带来最全面的讲解,并附上丰富示例助你快速上手。
一、函数签名与参数说明
torch.cumprod(input: Tensor,dim: int,*,dtype: Optional[torch.dtype] = None,out: Optional[Tensor] = None
) → Tensor
input
:任意维度的输入张量。dim
:指定沿哪个维度做累积乘积(0
表示第一个维度,以此类推)。dtype
(可选):输出张量的数据类型。如果原张量为整数且会溢出,可通过将其提升到更宽数据类型来避免溢出。out
(可选):预先分配好的张量,用于存储输出,避免额外内存分配。
二、基础用法
1. 一维张量累积乘积
import torchx = torch.tensor([1, 2, 3, 4])
y = torch.cumprod(x, dim=0)
print(y) # tensor([ 1, 2, 6, 24])
y[0] = 1
y[1] = 1 * 2 = 2
y[2] = 1 * 2 * 3 = 6
y[3] = 1 * 2 * 3 * 4 = 24
2. 二维张量按行/按列累积
x2 = torch.tensor([[1, 2, 3],[4, 5, 6]])
# 沿行(dim=1)累积
row_prod = torch.cumprod(x2, dim=1)
print(row_prod)
# tensor([[ 1, 2, 6],
# [ 4, 20, 120]])# 沿列(dim=0)累积
col_prod = torch.cumprod(x2, dim=0)
print(col_prod)
# tensor([[1, 2, 3],
# [4, 10, 18]])
三、dtype
参数:避免整数溢出与提升精度
当 input
为大整数且乘积超出类型范围时,会导致溢出。此时可指定更宽的数据类型:
x_int = torch.tensor([1000, 1000, 1000], dtype=torch.int32)
# 默认 int32 会溢出
print(torch.cumprod(x_int, dim=0))
# tensor([1000, -727, -728], dtype=torch.int32)# 改为 int64 避免溢出
print(torch.cumprod(x_int, dim=0, dtype=torch.int64))
# tensor([ 1000, 1000000, 1000000000])
四、典型应用场景
1. 几何序列生成
几何序列 a , a r , a r 2 , … a, ar, ar^2, … a,ar,ar2,… 可用累积乘积实现:
a, r, n = 2.0, 0.5, 5
ratios = torch.full((n,), r) # [r, r, r, r, r]
geom = a * torch.cumprod(ratios, dim=0)
print(geom)
# tensor([1.0000, 0.5000, 0.2500, 0.1250, 0.0625])
2. 概率分布的累积乘积
在构建离散分布的乘积模型时,用累乘来得到联合概率:
probs = torch.tensor([0.2, 0.3, 0.5])
# 标准化(确保和为1)
probs = probs / probs.sum()
# 获取依次乘积(注意:乘积非累加,因此并非 CDF)
joint = torch.cumprod(probs, dim=0)
print(joint)
# tensor([0.2000, 0.0600, 0.0300])
3. 模型门控或权重衰减
在 RNN、Transformer 等模型中,若需要对前 n 层或时间步做指数衰减,可用累积乘积计算衰减系数:
decay_rates = torch.linspace(0.9, 0.5, steps=4) # 每层不同衰减率
coeffs = torch.cumprod(decay_rates, dim=0) # 累积得到层间总衰减
print(coeffs)
# tensor([0.9000, 0.7200, 0.5040, 0.2520])
五、进阶示例:预分配 out
张量
为了在高性能场景下避免额外内存分配,可以先分配好输出张量,再将结果写入:
x = torch.arange(1, 1001, dtype=torch.float32)
out = torch.empty_like(x)
torch.cumprod(x, dim=0, out=out)
print(out[:5]) # tensor([1., 2., 6., 24., 120.])
六、小结
-
功能:
torch.cumprod
沿指定维度计算输入张量的累计乘积,返回新张量。 -
关键参数:
dim
:累积轴;dtype
:避免整数溢出/提升精度;out
:预分配输出提高性能。
-
常见应用:
- 几何序列生成;
- 概率分布乘积;
- 模型门控/权重衰减;
- 其它需要“前缀乘积”场景。