当前位置: 首页 > news >正文

pytorch checkpointing

是一种在训练深度神经网络时通过增加计算代价来换取显存优化的技术。它的核心思想是:在反向传播过程中动态重新计算中间激活值(activations),而不是保存所有中间结果。这对于显存受限的场景(如训练大型模型)非常有用。

直接上代码:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint# 1. 定义一个简单的 FFN 模型
class SimpleFFN(nn.Module):def __init__(self, input_dim=128, hidden_dim=256, output_dim=10):super().__init__()self.linear1 = nn.Linear(input_dim, hidden_dim)self.linear2 = nn.Linear(hidden_dim, hidden_dim)self.linear3 = nn.Linear(hidden_dim, output_dim)self.relu = nn.ReLU()def forward(self, x):# 2. 定义一个自定义的前向传播函数(用于 checkpoint)def custom_forward(x):x = self.linear1(x)x = self.relu(x)x = self.linear2(x)x = self.relu(x)x = self.linear3(x)return x# 3. 使用 checkpoint 包装前向传播return checkpoint(custom_forward, x)# 4. 初始化模型和数据
model = SimpleFFN()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()# 模拟输入数据
input_data = torch.randn(64, 128)  # batch_size=64, input_dim=128
target = torch.randn(64, 10)       # 模拟目标输出# 5. 前向传播、损失计算和反向传播
output = model(input_data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
  • 在反向传播时,custom_forward 会被重新调用,从输入 x 重新计算中间激活值,从而节省显存。
  • 显存占用:仅保存 linear3 的输出和 x,中间激活值在反向传播时动态计算。
  • 需要多次前向计算激活值,训练速度可能变慢
http://www.xdnf.cn/news/293347.html

相关文章:

  • 交换机工作原理(MAC地址表、VLAN)
  • P4168 [Violet] 蒲公英 Solution
  • 生物化学笔记:神经生物学概论10 运动节律的控制 运动时脑内活动 运动系统疾病及其治疗(帕金森、亨廷顿)
  • 【OSPF协议深度解析】从原理到企业级网络部署
  • 第15章:双星入侵与时间的迷雾
  • AIGC工具平台-图片转换线稿
  • 「OC」源码学习——对象的底层探索
  • 混搭文化数字社会学家解读,创新理解AI社会学网络社会学与数字人类学最新研究进展社会结构社会分层数字文化数字经济
  • 网络编程套接字(一)
  • PriorityQueue
  • 使用 Semantic Kernel 快速对接国产大模型实战指南(DeepSeek/Qwen/GLM)
  • Web前端开发:Grid 布局(网格布局)
  • ts学习(1)
  • 2024年408真题及答案
  • C++ 外观模式详解
  • php8 枚举使用教程
  • 稀疏性预测算法初步
  • 健康养生:从微小改变开始
  • 【YOLO11改进】改进Conv、颈部网络STFEN、以及引入PIOU用于小目标检测!
  • 基于Vue3开发:打造高性能个人博客与在线投票平台
  • 【MATLAB例程】基于RSSI原理的Wi-Fi定位程序,N个锚点(数量可自适应)、三维空间,轨迹使用UKF进行滤波,附代码下载链接
  • 反射-探索
  • CASS 3D使用等高线修改插件导致修后等高线高程变化的问题
  • 当前人工智能领域的主流高级技术及其核心方向
  • 10.施工测量
  • 引领变革的“Vibe Coding”:AI辅助编程的崛起与挑战
  • 某信服EDR3.5.30.ISO安装测试(一)
  • printf的终极调试大法
  • 分析 Docker 磁盘占用
  • FTP/TFTP/SSH/Telnet