当前位置: 首页 > ds >正文

Pytorch的梯度控制

在之前的实验中遇到一些问题,因为之前计算资源有限,我就想着微调其中一部分参数做,于是我误打误撞使用了with torch.no_grad,可是发现梯度传递不了,于是写下此文来记录梯度控制的两个方法与区别。

在PyTorch中,控制梯度计算对于模型训练和微调至关重要。这里区分两个常用方法:

1. tensor.requires_grad = False

  • 目标: 单个张量(通常是模型参数 nn.Parameter)。
  • 行为:
    • “参数冻结”:这个张量本身不会计算梯度 (.gradNone)。
    • “参数不更新”:优化器不会更新这个张量。
    • “梯度可穿透”:如果它参与的运算的输入是 requires_grad=True 的,梯度仍然会通过这个运算传递给输入。它不阻碍梯度流向更早的可训练层。
  • 场景:
    • 微调:冻结预训练模型的某些层,只训练其他层。
    • 例子:pretrained_layer.weight.requires_grad = False

2. with torch.no_grad():

  • 目标: 一个代码块 (with 语句块内部)。
  • 行为:
    • “全局梯度关闭”(块内):块内所有新创建的张量默认 requires_grad=False
    • “不记录计算图”:块内的运算不被追踪,不构建反向传播所需的计算图。
    • “梯度截断”:梯度流到这个块的边界就会停止,无法通过块内的操作继续反向传播
  • 场景:
    • 模型评估/推理 (Inference/Evaluation):不需要梯度,节省内存和计算。
    • 执行不需要梯度的任何计算。
    • 例子:
     with torch.no_grad():outputs = model(inputs)# ...其他评估代码
    

核心区别速记:

特性requires_grad=Falsewith torch.no_grad():
谁不更新?这个参数自己(块内)没人更新
梯度能过吗?能过!不能过! (被截断)
影响范围?单个张量整个代码块

一句话总结:

  • 想让某个参数不更新但梯度能流过,用 requires_grad=False
  • 想让一段代码完全不计算梯度也不让梯度流过,用 with torch.no_grad()

搞清楚这两者的区别,能在PyTorch中更灵活地控制模型的训练过程!

http://www.xdnf.cn/news/9911.html

相关文章:

  • 火山引擎扣子系列
  • vr中风--数据处理模型搭建与训练2
  • NLP学习路线图(十一):词干提取与词形还原
  • HTTP/HTTPS与SOCKS5三大代理IP协议,如何选择最佳协议?
  • 长安链起链调用合约时docker ps没有容器的原因
  • WPF prism
  • Arbitrary Response Filter Design and Analysis--任意响应滤波器设计与分析(待完成)
  • DexGarmentLab 论文翻译
  • CPP中CAS std::chrono 信号量与Any类的手动实现
  • Java四种访问权限修饰符详解
  • 霹雳吧啦Wz_深度学习-图像分类篇章_1.1 卷积神经网络基础_笔记
  • 【MQTT】
  • NUMA 架构科普:双路 CPU 系统是如何构建的?
  • 快速上手shell条件测试
  • Practice 2025.5.29 —— 二叉树进阶面试题(1)
  • 聊聊 Metasploit 免杀
  • 数字人引领政务新风尚:智能设备助力政务服务
  • OpenCV计算机视觉实战(9)——阈值化技术详解
  • 【仿生系统】qwen的仿生机器人解决方案
  • AI产品风向标:从「工具属性」到「认知引擎」的架构跃迁​
  • 国芯思辰| 霍尔电流传感器AH811为蓄电池负载检测系统安全护航
  • Java 实现下载指定minio目录下的所有内容到远程机器
  • ssm学习笔记(尚硅谷) day1
  • 生成式人工智能:重构软件开发的范式革命与未来生态
  • 预处理,咕咕咕
  • Cesium 展示——获取鼠标移动、点击位置的几种方法
  • 第四章、自平衡控制
  • 【Ubuntu远程桌面】
  • .NET WinForm图像识别二维码/条形码
  • 从零开始的数据结构教程(六) 贪心算法