当前位置: 首页 > web >正文

PyTorch 中cumprod函数计算张量沿指定维度的累积乘积详解和代码示例

torch.cumprod 是 PyTorch 中用于 计算张量沿指定维度的累积乘积(cumulative product) 的函数。


1、函数原型

torch.cumprod(input, dim, *, dtype=None, out=None) → Tensor

参数说明:

参数说明
input输入张量
dim累积乘积的维度
dtype可选:指定输出类型(默认与输入类型相同)
out可选:输出张量(用于 inplace)

2、功能说明

对于指定维度 dim,返回一个张量,其中每个元素是该位置及之前所有元素的乘积。


3、示例代码

示例 1:一维张量

import torchx = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
y = torch.cumprod(x, dim=0)
print("输入:", x)
print("累积乘积:", y)

输出:

输入: tensor([1., 2., 3., 4.])
累积乘积: tensor([ 1.,  2.,  6., 24.])

示例 2:二维张量,沿 dim=0(列)

x = torch.tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]], dtype=torch.float32)y = torch.cumprod(x, dim=0)
print(y)

输出:

tensor([[  1.,   2.,   3.],[  4.,  10.,  18.],[ 28.,  80., 162.]])

计算过程解释(逐列):

  • 第 1 列: [1, 4, 7][1, 1×4=4, 4×7=28]
  • 第 2 列: [2, 5, 8][2, 2×5=10, 10×8=80]
  • 第 3 列: [3, 6, 9][3, 3×6=18, 18×9=162]

示例 3:使用 dtype 强制类型

x = torch.tensor([1, 2, 3], dtype=torch.int32)
y = torch.cumprod(x, dim=0, dtype=torch.float32)
print(y)

输出:

tensor([1., 2., 6.])

4、综合应用示例

下面是一个完整的示例,展示了 torch.cumprod 在神经网络训练中如何用于 前向传播中累积权重乘积的计算。这种用法常见于:

  • 路径权重乘积模型(Path Weight Product Models)
  • 自定义神经网络结构中累积乘积(如神经ODE、概率模型)

4.1、示例背景

假设我们有一个网络结构:每一层只有一个权重因子,我们要计算所有权重乘积作为 forward 输出的一部分。


4.2、示例代码:累积权重乘积的自定义网络

import torch
import torch.nn as nnclass CumprodNet(nn.Module):def __init__(self, num_layers):super(CumprodNet, self).__init__()# 每层一个标量权重参数,初始化为 0.9 左右self.weights = nn.Parameter(torch.rand(num_layers) * 0.2 + 0.9)def forward(self, x):# 假设 x 是输入标量或批量张量# 计算权重的累积乘积path_weights = torch.cumprod(self.weights, dim=0)# 将每层的路径加权输出加总outputs = torch.stack([x * pw for pw in path_weights], dim=0)return outputs.sum(dim=0), path_weights  # 返回结果和路径乘积向量# 初始化模型
model = CumprodNet(num_layers=4)# 输入张量(可批量)
x = torch.tensor([1.0], requires_grad=True)# 前向传播
output, path_weights = model(x)# 打印结果
print("权重参数:", model.weights.data)
print("累积乘积:", path_weights)
print("最终输出:", output)# 反向传播
output.backward()
print("输入梯度:", x.grad)

4.3、输出说明(示例)

假设 self.weights = [0.91, 0.95, 1.01, 1.05]

cumprod 将计算:

[0.91,0.91 × 0.95 = 0.8645,0.8645 × 1.01 = 0.8731,0.8731 × 1.05 ≈ 0.9167]

然后每个都乘上输入 x,最后加总作为最终输出。


4.4、应用场景

  1. 路径加权神经网络
  2. 可学习的指数衰减控制
  3. 自定义 RNN、深层残差控制器中的动态路径参数建模
  4. 强化学习中的路径概率分布建模(Policy Gradient)

5、注意事项

  • cumprod 会在指定维度上,按顺序相乘;
  • 输入中如果有 0,后续的所有乘积都会变为 0
  • 常用于概率连乘、对数空间建模前的准备步骤(比如前向链式法则)。

6、与相关函数对比

函数功能
torch.cumsum累加和
torch.cumprod累乘积
torch.prod所有元素乘积(非逐步)
torch.cummax / cummin累积最大/最小值
http://www.xdnf.cn/news/12675.html

相关文章:

  • Oracle 19c RAC集群ADG搭建
  • MacOS下Homebrew国内镜像加速指南(2025最新国内镜像加速)
  • 计算机是如何⼯作的
  • 408第一季 - 数据结构 - 树与二叉树II
  • 《Brief Bioinform》: 鼠脑单细胞与Stereo-seq数据整合算法评估
  • 【Java实例-英雄对战】Java战斗之旅,既分胜负也决生死
  • 台式机电脑CPU天梯图2025年6月份更新:CPU选购指南及推荐
  • Canal环境搭建并实现和ES数据同步
  • App Search 和 Workplace Search 独立产品现已弃用
  • Cursor实现用excel数据填充word模版的方法
  • Fetch与Axios:区别、联系、优缺点及使用差异
  • 使用 C/C++ 和 OpenCV 提取图像的感兴趣区域 (ROI)
  • vue3+dify从零手撸AI对话系统
  • JavaWeb的一些基础技术
  • 在Ubuntu上使用 dd 工具制作U盘启动盘
  • 使用Transformer模型进行时间序列预测的完整解决方案,满足预测误差≤1.5%和注意力权重可视化的要求
  • GitHub 趋势日报 (2025年06月06日)
  • 2025年- H76-Lc184--55.跳跃游戏(贪心)--Java版
  • 有没有 MariaDB 5.5.56 对应 MySQL CONNECTION_CONTROL 插件
  • 信息最大化(Information Maximization)
  • Go语言进阶④:Go的数据结构和Java的有啥不一样
  • 光学字符识别(OCR)理论概述与实践教程
  • 动目标显示处理解析一(脉冲对消器)
  • Ubuntu 配置使用 zsh + 插件配置 + oh-my-zsh 美化过程
  • 前沿论文汇总(机器学习/深度学习/大模型/搜广推/自然语言处理)
  • 数据类型 -- 字符
  • SQL字符串截取函数全解析:LEFT、RIGHT、SUBSTRING 实战指南
  • 如何使用Jmeter进行压力测试?
  • MySQL-运维篇
  • 隐私计算时代B端页面安全设计:数据脱敏与权限体系升级路径