当前位置: 首页 > news >正文

pytorch基本运算-梯度运算:requires_grad_(True)和backward()

引言

前序学习进程中,已经对pytorch基本运算中的求导进行了基础讨论,相关文章链接为:

导数运算pytorch基本运算-导数和f-string-CSDN博客

实际上,求导是微分的进一步计算,要想求导的前一步其实是计算微分:

导数表达式:
f ′ ( x ) 或 d y d x f^{'}(x) 或 \frac{dy}{dx} f(x)dxdy
​微分表达式:
f ′ ( x ) d x 或 d y = f ′ ( x ) d x f^{'}(x) dx 或 {dy=f^{'}(x) dx} f(x)dxdy=f(x)dx
导数是某一点处的变化率,微分是某一点附近的变化量。
如果一个函数在多个点进行导数求解,或者说子安多维度上进行导数计算,实际上就是在求梯度。

pytorch自动微分获取梯度

为完整展示pytorch的梯度计算功能,将测试分为以下部分。

初始定义

首先是引入模块,完成变量定义:

# 导入模块
import torch
# 定义变量
x=torch.arange(3.0)
print('x=',x)

这里的输出结果是:

x= tensor([0., 1., 2.])

需要说明的是,因为pytorch默认对浮点数进行求导,所以定义变量的时候,pyorch.arange()使用了3.0而不是整数3。
紧接着,需要对变量执行梯度运算。

梯度运算标定

梯度运算标定的目的是,声明要对x进行梯度运算。任何没有经过提前标定的量,都不能正常执行梯度运算。

# 标记需要对x进行梯度计算
z=x.requires_grad_(True)
print('z=',z)

梯度标定使用requires_grad_(True),就像对话一样,需要求梯度_(需要)。
代码运行的效果为:

z= tensor([0., 1., 2.], requires_grad=True)

下一步是定义一个函数。

函数定义

这里定义一个简单函数:
f ( x ) = 2 x 2 f(x)=2x^{2} f(x)=2x2
具体定义代码为:

# 点乘定义
m=2*torch.dot(x,x)
print('m=',m)

计算微分对函数开展才有意义,所以必须定义函数,这里只是一个示例,也可以是其他函数。
torch.dot()函数的计算规则为:对位相乘然后求和。
代码运行效果为:

m= tensor(10., grad_fn=)

这里输出了两个部分:
第一部分是10,就是元素对位相乘后求和的效果(2X0X0+2X1X1+2X2X2=10)。
第二部分是grad_fn=,grad_fn的意思是grad_function,就是求导函数的意思,后面的MulBackward0是对求导函数的具体定义。
MulBackward0 表示这是一个乘法操作的梯度函数,具体拆开来:multiplication-backward,字面意思解释:乘法-反向传播。
这就是pytorch自动微分的核心机制:它可以自动测算求导函数的类型,比如这是一个自变量相乘的函数,并且指出要用哪种方法,比如这里要用反向传播法。
到这一步还无法计算微分,只是通过输出效果知道用反向传播方法计算微分,然后就是正式使用反向传播方法计算微分。

梯度计算

微分计算使用的代码为:

# 执行梯度运算
n=m.backward()
k=x.grad
print('n=',n)
print('k=',k)

这里用了两步,第一步是定义对函数m调用backward方法求倒数,然后具体是对x求导数,所获得计算结果为:

n= None
k= tensor([0., 4., 8.])

n对应的其实是方法定义,k才是具体的对x的求导效果。

实际上到这一步,如何用pytorch直接计算导数已经非常清晰:先要标定梯度计算的变量,然后要对函数声明梯度计算的方法,最后直接计算梯度。完整代码为:

# 导入模块
import torch
# 定义变量
x=torch.arange(3.0)
print('x=',x)
# 标记需要对x进行梯度计算
z=x.requires_grad_(True)
print('z=',z)
# 点乘定义
m=2*torch.dot(x,x)
print('m=',m)
# 执行梯度运算
n=m.backward()
k=x.grad
print('n=',n)
print('k=',k)

新的函数

未计算对新函数进行求导运算,需要提前将梯度清零,避免梯度计算效果彼此叠加,出现预料之外的效果。

梯度清零

代码为:

# 梯度清零
kk=x.grad.zero_()
print('kk=',kk)

代码运行效果为:

kk= tensor([0., 0., 0.])

定义新函数

代码为:

# 定义新函数
hh=x.sum()
print('hh=',hh)

这里使用了求和函数sim(),代码运行效果为:

hh= tensor(3., grad_fn=)

这里也输出了两个部分:
第一部分是3,就是元素求和的效果(0+1+2=3)。
第二部分是grad_fn=,grad_fn的意思是grad_function,就是求导函数的意思,后面的SumBackward0是对求导函数的具体定义。
SumBackward0 表示这是一个加法操作的梯度函数,具体拆开来:Sum-backward,字面意思解释:加法-反向传播。

导数计算

此时可以直接计算导数,代码为:

# 定义用backward方法计算导数
nn=hh.backward()
print('nn=',nn)
# 导数计算
tt=x.grad
print('tt=',tt)

代码运行效果为:

nn= None
tt= tensor([1., 1., 1.])

因为是各个变量直接叠加,所以每个变量前的系数都是1,所以导数运算的结果是[1.0,1.0,1.0].
此时的完整代码为:

# 导入模块
import torch
# 定义变量
x=torch.arange(3.0)
print('x=',x)
# 标记需要对x进行梯度计算
z=x.requires_grad_(True)
print('z=',z)
# 点乘定义
m=2*torch.dot(x,x)
print('m=',m)
# 执行梯度运算
n=m.backward()
k=x.grad
print('n=',n)
print('k=',k)
# 梯度清零
kk=x.grad.zero_()
print('kk=',kk)
# 定义新函数
hh=x.sum()
print('hh=',hh)
# 定义用backward方法计算导数
nn=hh.backward()
print('nn=',nn)
# 导数计算
tt=x.grad
print('tt=',tt)

完整的输出效果为:

x= tensor([0., 1., 2.])
z= tensor([0., 1., 2.], requires_grad=True)
m= tensor(10., grad_fn=)
n= None
k= tensor([0., 4., 8.])
kk= tensor([0., 0., 0.])
hh= tensor(3., grad_fn=)
nn= None
tt= tensor([1., 1., 1.])

梯度清零操作的讨论

前述有一个梯队清零的操作,如果没有这步操作,输出效果会如何变化,这里直接给出完整代码来测试。给出完整代码为:

# 导入模块
import torch
# 定义变量
x=torch.arange(3.0)
print('x=',x)
# 标记需要对x进行梯度计算
z=x.requires_grad_(True)
print('z=',z)
# 点乘定义
m=2*torch.dot(x,x)
print('m=',m)
# 执行梯度运算
n=m.backward()
k=x.grad
print('n=',n)
print('k=',k)
# 梯度清零
#kk=x.grad.zero_()
#print('kk=',kk)
# 定义新函数
hh=x.sum()
print('hh=',hh)
# 定义用backward方法计算导数
nn=hh.backward()
print('nn=',nn)
# 导数计算
tt=x.grad
print('tt=',tt)

此时的输出效果为:

x= tensor([0., 1., 2.])
z= tensor([0., 1., 2.], requires_grad=True)
m= tensor(10., grad_fn=)
n= None
k= tensor([0., 4., 8.])
hh= tensor(3., grad_fn=)
nn= None
tt= tensor([1., 5., 9.])

这里可以看到sum()函数的梯度输出为:[1.,5.,9.],这个结果的来源其实是:[0., 4., 8.]+[1., 1., 1.]=[1., 5., 9.]。
此处可见,及时将梯度清零很有必要。

总结

掌握了通过python+pytorch执行梯度运算的基本技巧。

http://www.xdnf.cn/news/1033507.html

相关文章:

  • 多个项目的信息流如何统一与整合
  • Spring AI Chat Tool Calling 指南
  • MySQL使用EXPLAIN命令查看SQL的执行计划
  • 13.20 LangChain多链协同架构实战:LanguageMentor实现67%对话连贯性提升
  • [每周一更]-(第144期):Go 定时任务的使用:从基础到进阶
  • mysql 创建大写字母的表名失败
  • HarmonyOS 组件复用 指南
  • React中使用Day.js指南
  • ABC410 : F - Balanced Rectangles
  • MIB 树的来源与实现深度解析
  • 计算机网络学习笔记:运输层概述UDP、TCP对比
  • Arduino入门教程​​​​​​​:4、打印字符到电脑
  • 疫菌QBD案例
  • Gartner《Build Scalable Data Products With This Step-by-Step Framework》学习报告
  • Linux系统安装MongoDB 8.0流程
  • 树莓派智能小车红外避障实验指导书
  • 当遇到“提交失败:404”的问题时,通常表明前端请求的URL无法正确匹配到后端的Servlet或资源。
  • 区间合并:区间合并问题
  • 前端与协议
  • 掌握应用分层:高内聚低耦合的艺术
  • 闲鱼与淘宝跨平台运营的自动化趋势
  • java 设计模式_行为型_17观察者模式
  • 【游资悟道】陈小群成长历史与股市悟道心法
  • Java面向对象this关键字和static关键字
  • Python 爬虫入门 Day 3 - 实现爬虫多页抓取与翻页逻辑
  • android关于native中Thread类的使用
  • Linux 系统目录结构概述-linux024
  • Tauri(2.5.1)+Leptos(0.8.2)开发自用桌面小程序
  • 系统设计基本功:理解语义
  • 【Linux】Linux多路复用-epoll