当前位置: 首页 > news >正文

pytorch 中前向传播和后向传播的自定义函数

系列文章目录


文章目录

  • 系列文章目录
  • 一、torch.autograd.function
  • 代码实例


  在开始正文之前,请各位姥爷动动手指,给小店增加一点访问量吧,点击小店,同时希望我的文章对你的学习有所帮助。本文也很简单,主要讲解pytorch的前向传播张量计算,和后向传播获取梯度计算。


一、torch.autograd.function

每一个原始的自动求导运算实际上是两个对 Tensor 操作的函数

  • forward 函数计算输入Tensor,一些列操作后得到输出Tensor
  • backward 接收输出 Tensor ,获取某个标量的梯度,并且计算输入Tensor相对于相同标量的梯度值。
    使用 apply 执行相应的运算

代码实例

  这个实例实现了重写line的功能,在以后的深度学习和构建扔工神经网络中常常使用。对 line 类重构,两个方法 forward 和 backward 都是静态的。实现的功能就是把三个张量运算: w * x + b.代码中在 return 中体现。

  • forward 传递的 ctx 用于保存上下文的管理器,调用 ctx.save_for_backward(变量名) 可以存储变量,调用ctx.saved_tensors 可以把对应的张量取出来。
  • grad_output 是上一层的梯度,返回回来应该遵循链式法则。
  • 导数计算:把 y 看做是因变量(编程中省略这个变量,具体体现 w * x + b),w, x, b 都看做是自变量。使用高数中的求导公式,大家就知道乘的系数是什么了。
import torchclass line(torch.autograd.Function):@staticmethoddef forward(ctx,w,x,b):# 第一个参数是管理器,对变量进行存储# y = w*x+bctx.save_for_backward(w,x,b)# 定义前向运算return w*x+b@staticmethoddef backward(ctx, grad_output):# 上下文管理器,第二个参数是上一级梯度,表达了一个链式法则# 我们计算梯度,需要乘上一级梯度w,x,b = ctx.saved_tensors# dy/dw = xgrad_w = grad_output * x# dy/dx = wgrad_x = grad_output * w# dy/db = 1grad_b = grad_output * 1return grad_w,grad_x,grad_bw = torch.randn(2,2,requires_grad=True)
x = torch.randn(2,2,requires_grad=True)
b = torch.randn(2,2,requires_grad=True)# 调用重写的line函数
out = line.apply(w,x,b)
out.backward(torch.ones(2,2))print("x 的内容:",x)
print("w 的内容:",w)
print("b 的内容:",b)
print("grad_x",x.grad)
print("grad_w",w.grad)
print("grad_b",b.grad)
图 1求导获取的梯度
通过图 1 可知,y 对 x 方向的导数就是 w,y 对 w 方向的导数就是 x, y 对 b 的导数是 1 。大家可以结合图片来理解。我们可以把张量抽象看作是一个变量,这样可以唤醒我们远古的高数知识。
http://www.xdnf.cn/news/981703.html

相关文章:

  • vscode界面设置透明度--插件Glasslt-VSC
  • 【DETR目标检测】ISTD-DETR:一种基于DETR与超分辨率技术的红外小目标检测深度学习算法
  • 《HarmonyOSNext弹窗:ComponentContent动态玩转企业级弹窗》
  • 新闻类鸿蒙应用全链路测试实践:性能、兼容性与体验的深度优化
  • React Context 性能问题及解决方案深度解析
  • 【普及/提高−】P1025 ——[NOIP 2001 提高组] 数的划分
  • Cilium动手实验室: 精通之旅---23.Advanced Gateway API Use Cases
  • codeforces C. Devyatkino
  • Java并发工具包
  • 【59 Pandas+Pyecharts | 淘宝华为手机商品数据分析可视化】
  • 深度解读谷歌Brain++液态神经网络:重塑动态智能的流体计算革命
  • Gogs:一款极易搭建的自助 Git 服务
  • [Java恶补day22] 240. 搜索二维矩阵Ⅱ
  • React第六十节 Router中createHashRouter的具体使用详解及案例分析
  • android studio向左向右滑动页面
  • Babylon.js引擎
  • MMDG++:构筑多模态人脸防伪新防线,攻克伪造攻击与场景漂移挑战
  • java面向对象高级部分
  • 大数据服务器和普通服务器之间的区别
  • LDStega论文阅读笔记
  • 【基于阿里云上Ubantu系统部署配置docker】
  • RawTherapee:专业RAW图像处理,免费开源
  • 【AI智能体】Coze 数据库从使用到实战操作详解
  • Docker Compose完整教程
  • day51python打卡
  • AI时代的行业重构:机遇、挑战与生存法则
  • Spring Boot + MyBatis日志前缀清除方法
  • Grounding Language Model with Chunking‑Free In‑Context Retrieval (CFIC)
  • mysql如何快速生成测试大数据库
  • Java高频面试之并发编程-27