当前位置: 首页 > news >正文

pytorch小记(二十二):全面解读 PyTorch 的 `torch.cumprod`——累积乘积详解与实战示例

pytorch小记(二十二):全面解读 PyTorch 的 `torch.cumprod`——累积乘积详解与实战示例

    • 一、函数签名与参数说明
    • 二、基础用法
      • 1. 一维张量累积乘积
      • 2. 二维张量按行/按列累积
    • 三、`dtype` 参数:避免整数溢出与提升精度
    • 四、典型应用场景
      • 1. 几何序列生成
      • 2. 概率分布的累积乘积
      • 3. 模型门控或权重衰减
    • 五、进阶示例:预分配 `out` 张量
    • 六、小结


在深度学习与科学计算中,往往需要沿某个维度追踪“前面所有元素的乘积”,比如几何序列计算、概率分布构建、模型门控/权重衰减等场景。PyTorch 提供的 torch.cumprod 函数可以一行代码搞定这一需求。本文将从函数签名、参数含义、基础用法,到进阶示例、典型应用场景,为你带来最全面的讲解,并附上丰富示例助你快速上手。


一、函数签名与参数说明

torch.cumprod(input: Tensor,dim: int,*,dtype: Optional[torch.dtype] = None,out: Optional[Tensor] = None
) → Tensor
  • input:任意维度的输入张量。
  • dim:指定沿哪个维度做累积乘积(0 表示第一个维度,以此类推)。
  • dtype(可选):输出张量的数据类型。如果原张量为整数且会溢出,可通过将其提升到更宽数据类型来避免溢出。
  • out(可选):预先分配好的张量,用于存储输出,避免额外内存分配。

二、基础用法

1. 一维张量累积乘积

import torchx = torch.tensor([1, 2, 3, 4])
y = torch.cumprod(x, dim=0)
print(y)  # tensor([ 1,  2,  6, 24])
  • y[0] = 1
  • y[1] = 1 * 2 = 2
  • y[2] = 1 * 2 * 3 = 6
  • y[3] = 1 * 2 * 3 * 4 = 24

2. 二维张量按行/按列累积

x2 = torch.tensor([[1, 2, 3],[4, 5, 6]])
# 沿行(dim=1)累积
row_prod = torch.cumprod(x2, dim=1)
print(row_prod)
# tensor([[  1,   2,   6],
#         [  4,  20, 120]])# 沿列(dim=0)累积
col_prod = torch.cumprod(x2, dim=0)
print(col_prod)
# tensor([[1, 2,  3],
#         [4, 10, 18]])

三、dtype 参数:避免整数溢出与提升精度

input 为大整数且乘积超出类型范围时,会导致溢出。此时可指定更宽的数据类型:

x_int = torch.tensor([1000, 1000, 1000], dtype=torch.int32)
# 默认 int32 会溢出
print(torch.cumprod(x_int, dim=0))
# tensor([1000,  -727,  -728], dtype=torch.int32)# 改为 int64 避免溢出
print(torch.cumprod(x_int, dim=0, dtype=torch.int64))
# tensor([      1000,    1000000, 1000000000])

四、典型应用场景

1. 几何序列生成

几何序列 a , a r , a r 2 , … a, ar, ar^2, … a,ar,ar2, 可用累积乘积实现:

a, r, n = 2.0, 0.5, 5
ratios = torch.full((n,), r)               # [r, r, r, r, r]
geom = a * torch.cumprod(ratios, dim=0)
print(geom)
# tensor([1.0000, 0.5000, 0.2500, 0.1250, 0.0625])

2. 概率分布的累积乘积

在构建离散分布的乘积模型时,用累乘来得到联合概率:

probs = torch.tensor([0.2, 0.3, 0.5])
# 标准化(确保和为1)
probs = probs / probs.sum()
# 获取依次乘积(注意:乘积非累加,因此并非 CDF)
joint = torch.cumprod(probs, dim=0)
print(joint)
# tensor([0.2000, 0.0600, 0.0300])

3. 模型门控或权重衰减

在 RNN、Transformer 等模型中,若需要对前 n 层或时间步做指数衰减,可用累积乘积计算衰减系数:

decay_rates = torch.linspace(0.9, 0.5, steps=4)  # 每层不同衰减率
coeffs = torch.cumprod(decay_rates, dim=0)      # 累积得到层间总衰减
print(coeffs)
# tensor([0.9000, 0.7200, 0.5040, 0.2520])

五、进阶示例:预分配 out 张量

为了在高性能场景下避免额外内存分配,可以先分配好输出张量,再将结果写入:

x = torch.arange(1, 1001, dtype=torch.float32)
out = torch.empty_like(x)
torch.cumprod(x, dim=0, out=out)
print(out[:5])  # tensor([1., 2., 6., 24., 120.])

六、小结

  • 功能torch.cumprod 沿指定维度计算输入张量的累计乘积,返回新张量。

  • 关键参数

    • dim:累积轴;
    • dtype:避免整数溢出/提升精度;
    • out:预分配输出提高性能。
  • 常见应用

    1. 几何序列生成;
    2. 概率分布乘积;
    3. 模型门控/权重衰减;
    4. 其它需要“前缀乘积”场景。
http://www.xdnf.cn/news/497287.html

相关文章:

  • Java求职面试:从核心技术到大数据与AI的场景应用
  • [Android] 安卓彩蛋:Easter Eggs v3.4.0
  • 第五项修炼:打造学习型组织
  • 前端基础之CSS
  • 大语言模型 11 - 从0开始训练GPT 0.25B参数量 MiniMind2 准备数据与训练模型 DPO直接偏好优化
  • 【诊所电子处方专用软件】佳易王个体诊所门诊电子处方开单管理系统:零售药店电子处方服务系统#操作简单#诊所软件教程#药房划价
  • Java 快速转 C# 教程
  • 30、WebAssembly:古代魔法——React 19 性能优化
  • 手撕四种常用设计模式(工厂,策略,代理,单例)
  • 设计模式Java
  • IDEA反斜杠路径不会显示JUnit运行的工作目录配置问题
  • 信奥赛-刷题笔记-栈篇-T2-P1981表达式求值0517
  • 在Maven中替换文件内容的插件和方法
  • 防范Java应用中的恶意文件上传:确保服务器的安全性
  • 【机器人】复现 WMNav 具身导航 | 将VLM集成到世界模型中
  • 结构化思考力_第一章_明确理念打基础
  • 西门子 Teamcenter13 Eclipse RCP 开发 1.2 工具栏 开关按钮
  • WPS JS宏实现去掉文档中的所有空行
  • 深入解析Spring Boot与Redis集成:高效缓存实践
  • Ansible模块——设置软件仓库和安装软件包
  • Python海龟绘图(Turtle Graphics)核心函数和关键要点
  • 【Linux网络】内网穿透
  • 当语言模型学会犯错和改正:搜索流(SoS)方法解析
  • 兰亭妙微:用系统化思维重构智能座舱 UI 体验
  • 【Redis】零碎知识点(易忘 / 易错)总结回顾
  • linux标准库头文件解析
  • Go语言实现链式调用
  • vscode用python开发maya联动调试设置
  • 游戏引擎学习第288天:继续完成Brains
  • 98. 验证二叉搜索树