当前位置: 首页 > ops >正文

PyTorch自动求导

1. 计算图构建过程

x = torch.ones(5, requires_grad=True)  # 定义叶子节点,启用梯度跟踪
y = x + 2                             # 加法操作,生成中间节点 y
z = y * y * 3                         # 平方与乘法操作,生成中间节点 z
out = z.mean()                        # 标量输出(损失函数)
  • 动态计算图构建​:

    每行代码触发一个操作,PyTorch 动态记录操作依赖关系,生成有向无环图(DAG):

    x → (Add) → y → (Pow + Mul) → z → (Mean) → out

    节点类型:

    • 叶子节点​:用户直接创建的 xx.is_leaf = True)。
    • 非叶子节点​:y, z, out由运算生成(grad_fn属性记录操作类型)
  • 梯度跟踪机制​:

    设置 requires_grad=True后,所有依赖 x的中间节点自动继承此属性(如 y.requires_grad=True


2. 反向传播与梯度计算

out.backward()  # 触发反向传播
  • 反向传播流程​:
    1. 1.out开始反向遍历​:因 out是标量(shape=()),无需额外指定梯度权重
    2. 2.

      链式法则应用​:

      • out = z.mean()→ ∂zi​∂out​=51​(z有 5 个元素)。
      • z = 3y^2→ ∂yi​∂zi​​=6yi​。
      • y = x + 2→ ∂xi​∂yi​​=1
    3. 3.​梯度计算​:

      ∂xi​∂out​=∂zi​∂out​⋅∂yi​∂zi​​⋅∂xi​∂yi​​=51​⋅6yi​⋅1=56​(xi​+2)。

  • •​梯度存储​:

    结果存入叶子节点 x.grad,非叶子节点(如 y, z)的梯度默认不保留以节省内存


3. 梯度结果验证

print(f"x 的梯度: {x.grad}")  # 输出:tensor([3.6000, 3.6000, 3.6000, 3.6000, 3.6000])
  • •​数学推导​:

    代入 xi​=1:

    ∂xi​∂out​=56​(1+2)=518​=3.6。

    与代码输出一致,验证了链式法则的正确性


4. 梯度累积问题

  • •​默认行为​:

    backward()计算的梯度会累加x.grad。若多次执行 out.backward(),梯度将叠加(如运行两次后 x.grad变为 [7.2, 7.2, ...]

  • 解决方案​:

    训练循环中需在每次反向传播前调用 x.grad.zero_()optimizer.zero_grad()清零梯度


关键概念总结

概念

说明

代码示例

叶子节点

用户直接创建的张量,梯度计算终点

x = torch.ones(5, requires_grad=True)

动态计算图

运行时动态构建的操作依赖图,反向传播后自动释放

y = x + 2生成 AddBackward节点

非标量反向传播

out非标量(如向量),需传入 gradient参数作为权重矩阵

z.backward(torch.ones_like(z))

梯度保留

设置 retain_graph=True可保留计算图,支持多次反向传播

out.backward(retain_graph=True)


提示​:理解计算图结构是调试自动求导的关键。可通过 print(y.grad_fn)查看操作类型(如输出 <AddBackward0>),或使用 torchviz库可视化计算图

http://www.xdnf.cn/news/18190.html

相关文章:

  • OpenHarmony之打造全场景智联基座的“分布式星链 ”WLAN子系统
  • Java试题-选择题(11)
  • Consul- acl机制!
  • 【Pycharm虚拟环境中安装Homebrew,会到系统中去吗】
  • 【牛客刷题】岛屿数量问题:BFS与DFS解法深度解析
  • Java NIO (New I/O) 深度解析
  • windows电脑对于dell(戴尔)台式的安装,与创建索引盘,系统迁移到新硬盘
  • Nacos-8--分析一下nacos中的AP和CP模式
  • 从现场到云端的“通用语”:Kepware 在工业互联中的角色、使用方法与本土厂商(以胡工科技为例)的差异与优势
  • vLLM加载lora
  • 【MATLAB例程】水下机器人AUV的长基线定位,适用于三维环境,EKF融合长基线和IMU数据,锚点数量可自适应,附下载链接
  • (一)八股(数据库/MQ/缓存)
  • 在Ubuntu上安装并使用Vue2的基本教程
  • week2-[一维数组]最大元素
  • 监督分类——最小距离分类、最大似然分类、支持向量机
  • 第一章 认识单片机
  • 一个基于前端技术的小狗寿命阶段计算网站,帮助用户了解狗狗在不同年龄阶段的特点和需求。
  • 芯显 15.6寸G156HAE02.0 FHD 宽温液晶模组技术档案
  • Spring Boot应用实现图片资源服务
  • 【实时Linux实战系列】基于实时Linux的物联网系统设计
  • [嵌入式embed][Qt]一个新手Qt开发环境5.12.12
  • VS Code 终端完全指南
  • 机器学习中的「损失函数」:模型优化的核心标尺
  • 2025.8.19总结
  • Qt猜数字游戏项目开发教程 - 从零开始构建趣味小游戏
  • BCT8937A Class T Audio Amplifier
  • GPFS不同存储方式的优劣
  • 【数据结构】使用队列解决二叉树问题
  • 4.pod生命周期和健康检测以及使用kubectl管理Kubernetes容器平台
  • B站 韩顺平 笔记 (Day 23)