当前位置: 首页 > java >正文

2025-04-24 Python深度学习4—— 计算图与动态图机制

文章目录

  • 1 计算图
  • 2 叶子结点
  • 2 自动求导
    • 2.1 示例
    • 2.2 权重求导
  • 4 梯度函数

本文环境:

  • Pycharm 2025.1
  • Python 3.12.9
  • Pytorch 2.6.0+cu124

1 计算图

​ 计算图是用来描述运算的有向无环图,由节点(Node)和边(Edge)组成。

  • 结点表示数据(如向量,矩阵,张量)。
  • 边表示运算(如加法、乘法、激活函数)。

​ 表达式 y = ( x + w ) ∗ ( w + 1 ) y = (x + w) * (w + 1) y=(x+w)(w+1) 可拆解为:

  1. a = x + w a=x +w a=x+w
  2. b = w + 1 b=w+1 b=w+1
  3. y = a ∗ b y=a *b y=ab
image-20250423222053549

​ 在动态图中,每一步操作即时生成计算节点,可灵活插入调试代码。

特性动态图(PyTorch)静态图(TensorFlow 1.x)
搭建方式运算与建图同时进行(即时执行)先定义完整计算图,再执行(延迟执行)
灵活性高,可随时修改计算流程低,计算图固定后不可更改
调试难度易调试(逐行执行)难调试(需先构建完整图)
性能优化运行时优化较少可预先优化计算路径(如算子融合)

2 叶子结点

  • 叶子结点:用户创建的结点称为叶子结点(如 x 与 w),是计算图的根基。
    • is_leaf:指示张量是否为叶子结点。

代码示例

import torchw = torch.tensor(1., requires_grad=True)
x = torch.tensor(2., requires_grad=True)
a = w + x
b = w + 1
y = a * bw.is_leaf, x.is_leaf, a.is_leaf, b.is_leaf, y.is_leaf
image-20250423223851864

2 自动求导

​ 自动梯度计算:通过构建计算图(Computational Graph)自动计算张量的梯度,无需手动推导。

tensor.backward()

image-20250423225239909
  • gradient:多梯度权重。
  • retain_graph:保留计算图(默认释放,用于多次反向传播)。
  • create_graph:创建导数计算图(用于高阶求导)。
  • inputs:梯度将被累积到 .grad 中的输入,所有其他张量将被忽略。如果没有提供,则梯度将被累积到用于计算:attr:tensors 的所有叶子张量

2.1 示例

​ 例如,当 x = 2 , w = 1 x=2,w=1 x=2,w=1
y = x w 2 + ( x + 1 ) w + x y=xw^2+(x+1)w+x y=xw2+(x+1)w+x

$$ \begin{aligned}\frac{\partial y}{\partial w}&=\frac{\partial y}{\partial a}\frac{\partial a}{\partial w}+\frac{\partial y}{\partial b}\frac{\partial b}{\partial w}\\&=b*1+a*1\\&=(w+1)+(x+w)\\&=2*w+1\\&=2*1+2+1\\&=5\end{aligned} $$

代码示例

import torchw = torch.tensor(1., requires_grad=True)
x = torch.tensor(2., requires_grad=True)
a = w + x
b = w + 1
y = a * b
y.backward()w.grad, x.grad, a.grad, b.grad, y.grad
image-20250423223915053

注意

  • 反向传播后,非叶子节点(如 A, B, Y)的梯度默认被释放以节省内存。
  • 使用 retain_grad() 保留非叶子节点梯度。

代码示例

import torchw = torch.tensor(1., requires_grad=True)
x = torch.tensor(2., requires_grad=True)
a = w + x
a.retain_grad()  # 保留 a 的梯度
b = w + 1
y = a * b
y.backward()w.grad, x.grad, a.grad, b.grad, y.grad
image-20250423224134891

2.2 权重求导

代码示例

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)a = w + x
b = w + 1y0 = a * b  # (x + w) * (x + 1)   dy0/dw = 5
y1 = a + b  # (x + w) + (x + 1)   dy1/dw = 2loss = torch.cat([y0, y1], dim=0)
grad_tensors = torch.tensor([1., 1.])loss.backward(gradient=grad_tensors)  # [1., 1.] * [5., 2.]w.grad, x.grad
image-20250423230835262

​ 将权重改为 [1., 2.]

grad_tensors = torch.tensor([1., 2.])loss.backward(gradient=grad_tensors)  # [1., 2.] * [5., 2.]w.grad, x.grad
image-20250423232910231

torch.autograd.grad()

image-20250424153543391

​ 功能:求取梯度。

  • outputs:用于求导的张量,如 loss。
  • inputs:需要梯度的张量。
  • grad_outputs:多梯度权重。
  • retain_graph:保存计算图。
  • create_graph:创建导数计算图,用于高阶求导。
  • only_inputs:当前已废弃(deprecated),会被直接忽略。
  • allow_unused:控制是否允许输入中存在未被使用的变量。
    • 如果设为 False(默认值取决于 materialize_grads),当输入的某些变量在前向计算中未被使用时,会直接报错(因为这些变量的梯度始终为零)。
    • 如果设为 True,则跳过这些未使用的变量,不会报错,其梯度返回 None
  • is_grads_batched:是否将 grad_outputs 的第一维度视为批处理维度。如果设为 True,会使用 PyTorch 的 vmap 原型功能,将 grad_outputs 中的每个向量视为一个批处理样本,一次性计算整个批量的向量-雅可比积(而非手动循环计算)。
  • materialize_grads:控制是否将未使用输入的梯度显式置零(而非返回 None)。
    • 如果设为 True,未被使用的输入的梯度会返回零张量;若设为 False,则返回 None
    • 如果 materialize_grads=Trueallow_unused=False,会直接报错(因为逻辑冲突)。

代码示例

x = torch.tensor([3.], requires_grad=True)  # x = 3
y = x * x  # y = x^2grad_1 = torch.autograd.grad(y, x, create_graph=True)  # 1 阶导:y = 2x
grad_2 = torch.autograd.grad(grad_1, x)  # 2 阶导:y = 2grad_1, grad_2
image-20250424153514587

autograd 小贴士

  1. 梯度不自动清零。

    w = torch.tensor([1.], requires_grad=True)
    x = torch.tensor([2.], requires_grad=True)for i in range(3):a = w + xb = w + 1y = a * by.backward()print(w.grad, x.grad)# w.grad.zero_()  # 梯度不自动清零,则会累加# x.grad.zero_()
    
    image-20250424154734659
  2. 依赖于 requires._grad = True 叶子结点的结点,requires._grad 默认为 True

    w = torch.tensor([1.], requires_grad=True)
    x = torch.tensor([2.], requires_grad=True)a = w + x
    b = w + 1
    y = a * ba.requires_grad, b.requires_grad, y.requires_grad
    
    image-20250424154932262
  3. 叶子结点不可执行 in-place 操作原地修改数据,否则自动求导结果会出现错误。

    a = torch.ones((1, ))
    print(id(a), a)a = a + 1
    print(id(a), a)a += 1  # in-place 操作原地修改数据
    print(id(a), a)
    
    image-20250424155125740

4 梯度函数

​ grad_fn:记录创建张量时的运算方法,用于反向传播时的求导规则。

  • y.grad_fn=<MulBackward0>
  • a.grad_fn=<AddBackward0>
  • b.grad_fn=<AddBackward0>

代码示例

w.grad_fn, x.grad_fn, a.grad_fn, b.grad_fn, y.grad_fn
image-20250423224551637

http://www.xdnf.cn/news/1721.html

相关文章:

  • 极狐GitLab 如何 cherry-pick 变更?
  • STM32移植最新版FATFS
  • Godot开发2D冒险游戏——第二节:主角光环整起来!
  • C# new Bitmap(32043, 32043, PixelFormat.Format32bppArgb)报错:参数无效,如何将图像分块化处理?
  • STM32F103_HAL库+寄存器学习笔记20 - CAN发送中断+ringbuffer + CAN空闲接收中断+接收所有CAN报文+ringbuffer
  • Python爬虫去重策略:增量爬取与历史数据比对
  • VulnHub-DC-2靶机渗透教程
  • zip是 Python 中 `zip` 函数的一个用法
  • 数模学习:一,层次分析法
  • flutter 小知识
  • 在Ubuntu 18.04 和 ROS Melodic 上编译 UFOMap
  • 跨浏览器音频录制:实现兼容的音频捕获与WAV格式生成
  • Spring Security认证流程
  • LabVIEW实现Voronoi图绘制功能
  • 【MQ篇】初识RabbitMQ保证消息可靠性
  • 信息系统项目管理工程师备考计算类真题讲解七
  • KMS工作原理及其安全性分析
  • Java Agent 注入 WebSocket 篇
  • java方法引用
  • kotlin和MVVM的结合使用总结(二)
  • 一种Spark程序运行指标的采集与任务诊断实现方式
  • CE第二次作业
  • NODE_OPTIONS=--openssl-legacy-provider vue-cli-service serve
  • Git 的基本概念和使用方式
  • C++跨平台开发要点
  • Spring AI 核心概念
  • 【Linux】网络基础和socket
  • HGDB安全版单机修改用户密码
  • spring-ai使用Document存储至milvus的数据结构
  • dockercompose文件仓库