当前位置: 首页 > news >正文

如何用更少的显存训练 PyTorch 模型

文章目录

1、引言

2、自动混合精度训练

3、低精度训练

4、梯度检查点

5、通过梯度累积减小批量大小

6、张量分片与分布式训练

7、高效数据加载

8、使用 In-Place 操作

9、Activation and Parameter Offloading

10、使用更精简的优化器

11、高级策略

12、总结


1、引言

在训练大型深度学习模型(包括LLM和视觉Transformer)时,最常见的瓶颈之一就是显存消耗达到峰值。由于大多数人无法使用大规模的GPU集群,因此在本文中将概述一些技术和策略,在不牺牲模型性能和预测准确性的情况下,将显存消耗降低近20倍。请记住,这些技术中的大多数应用并不互相排斥,可以很容易地结合使用,以提高显存效率。

2、自动混合精度训练

混合精度训练结合了16位(FP16)和32位(FP32)浮点格式。其核心思想是在低精度下执行大部分数学运算,从而降低显存带宽和存储需求,同时在计算的关键环节保留必要的精度保障。通过使用FP16存储激活值和梯度,这些张量的显存占用量可减少约一半。但需注意,某些网络层或运算仍需保持FP32精度以避免数值不稳定问题。值得庆幸的是,PyTorch对自动混合精度(AMP)的原生支持极大简化了这一过程。

注意这里是混合精度训练而不是低精度训练

什么是混合精度训练?

混合精度训练结合使用16位(FP16)和32位(FP32)浮点格式以保持模型精度。通过使用16位精度计算梯度,相比全32位精度计算,这一过程可大幅加快运算速度并显著减少显存占用。这种方法在显存或计算资源受限的场景下尤为实用。

之所以采用混合精度而非低精度这一表述,是因为并非所有参数或运算都被转换为16位格式。实际上,训练过程会在32位与16位运算之间动态切换,这种精度层级的有序交替正是该技术被称为混合精度的根本原因。

如上述示意图所示,混合精度训练流程首先将权重转换为低精度格式(FP16)以加速计算,随后梯度计算在低精度环境下完成,但为确保数值稳定性,这些梯度会被重新转换为高精度格式(FP32),最终经过缩放处理的梯度将用于更新原始权重。因此,通过这种机制既能提升训练效率,又不会牺牲网络的整体精度与稳定性。

如前所述,使用 torch.cuda.amp.autocast( ) 可以轻松启用该功能,一个简单的代码示例片段如下:

import torch
from torch.cuda.amp import autocast, GradScaler# Assume your model and optimizer have been defined elsewhere.
model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler()
for data, target in data_loader:optimizer.zero_grad()# Enable mixed precisionwith autocast():output = model(data)loss = loss_fn(output, target)# Scale the loss and backpropagatescaler.scale(loss).backward()scaler.step(optimizer)scaler.update()

3、低精度训练

如原文所述,理论上可以更进一步尝试完全使用16位低精度(而非混合精度)进行训练。但此时可能因16位浮点数的固有精度限制出现NaN值异常。为解决这一问题,业界开发了多种新型浮点格式,其中由谷歌专门为此研发的BF16应用较为广泛。简而言之,相较于标准的FP16,BF16拥有更大的动态范围——这种扩展的动态范围使其能够更精确地表示极大或极小的数值,从而更适配可能遭遇广泛数值区间的深度学习场景。虽然其较低的尾数精度在某些情况下可能影响计算准确性或引发舍入误差,但在大多数实践中对模型性能的影响微乎其微。

FP16与BF16的动态范围对比

虽然这种格式最初是为TPU开发的,但在大多数现代GPU(Nvidia Ampere架构及更高版本)也支持这种格式。大家可以使用以下方法检查您的GPU是否支持这种格式:

import torch
print(torch.cuda.is_bf16_supported())  # should print True

4、梯度检查点

即使采用混合精度与低精度训练,这些大型模型仍会生成大量中间张量,消耗可观的显存。梯度检查点技术通过在前向传播过程中选择性存储部分中间结果来解决这一问题——未被保存的中间张量将在反向传播阶段重新计算。尽管这会引入额外的计算开销,却能显著节省显存资源。

通过策略性选择需设置检查点的网络层,大家可通过动态重新计算激活值而非存储它们来减少显存使用。这种时间与内存的折中策略对于具有深层架构的模型特别有益,因为中间激活值占用了大部分内存消耗。以下是一个简单的使用示例:

import torch
from torch.utils.checkpoint import checkpoint
def checkpointed_segment(input_tensor):# This function represents a portion of your model# which will be recomputed during the backward pass.# You can create a custom forward pass for this segment.return model_segment(input_tensor)
# Instead of a conventional forward pass, wrap the segment with checkpoint.
output = checkpoint(checkpointed_segment, input_tensor)

采用该方法,在大多数情况下可使激活值的显存占用量降低40%至50%。尽管反向传播阶段因此增加了额外的计算量,但在GPU显存成为瓶颈的场景下,这种以时间换空间的策略通常是可接受的。

5、通过梯度累积减小批量大小

通过最初的方法,你可能会问自己:

为什么不干脆减少batchsize大小?

通过减小批量大小的确是减少显存占用最直接的方法,但需注意的是,这种方式在多数情况下会导致模型预测性弱于使用更大批量训练的模型。因此需要在显存限制与模型效果之间谨慎权衡。

那么如何达到平衡呢?

这正是梯度累积技术发挥作用之处!该方法通过在训练过程中虚拟增大有效批量规模:其核心原理是先在较小的批量上计算梯度,并经过多次迭代的累积(通过采用累加或平均方式),而非在每批次处理后立即更新模型参数。当累积梯度达到目标“虚拟”批量规模时,才使用聚合后的梯度一次性完成模型权重的更新。

这种技术的一个主要缺点是大大增加了训练时间。

6、张量分片与分布式训练

对于单个GPU无法容纳的庞大训练模型(即使经过上述优化),完全分片数据并行(FSDP)是不可或缺的。FSDP将模型参数、梯度和优化器状态分散到多个GPU上。这不仅能将巨大的模型放入显存,还能通过更好地分配通信开销提高训练效率。

FSDP不在每个GPU上维护模型的完整副本,而是在可用设备之间分配模型参数。在执行前向或后向传递时,只有相关的分片被加载到显存中。这种分片机制大大降低了对每台设备显存的需求,结合上述技术,在某些情况下甚至可以将显存需求降低10倍。

Tensor Parallel

样例如下:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
# Initialize your model and ensure it is on the correct device.
model = MyLargeModel().cuda()
# Wrap the model in FSDP for sharded training across GPUs.
fsdp_model = FSDP(model)

7、高效数据加载

在显存优化实践中,数据加载环节常被忽视。虽然优化重点通常集中在模型内部结构与计算过程上,但低效的数据处理可能引发不必要的性能瓶颈,同时影响显存占用与训练速度。若不确定如何优化数据加载器,可遵循以下经验法则:优先启用固定内存(Pinned Memory)与多工作进程(Multiple Workers)配置。

from torch.utils.data import DataLoader# Create your dataset instance and then the DataLoader with pinned memory enabled.
train_loader = DataLoader(dataset,batch_size=64,shuffle=True,num_workers=4,      # Adjust based on your CPU capabilitiespin_memory=True     # Enables faster host-to-device transfers
)

8、使用 In-Place 操作

在张量运算中,若未谨慎管理内存,每次操作都可能生成新对象。原地(In-Place)操作通过直接修改现有张量而非创建副本,可有效减少内存碎片化与总体内存占用。这种特性尤其有利于降低迭代训练循环中的临时内存分配开销。例如:

import torch
x = torch.randn(100, 100, device='cuda')
y = torch.randn(100, 100, device='cuda')
# Using in-place addition
x.add_(y)  # Here x is modified directly instead of creating a new tensor

9、Activation and Parameter Offloading

即便综合运用前述所有优化技术,在训练超大规模模型时,仍可能因海量中间激活值的瞬时占用而触及GPU显存容量极限。此时,中间数据卸载技术可作为额外的安全阀机制——其核心思路是将部分非即时必需的中间数据临时转换至CPU内存,从而为GPU显存腾出关键空间,确保训练流程持续进行。

我们通过策略性将部分激活值和或模型参数卸载至CPU内存,从而将GPU显存专用于核心计算任务。虽然如DeepSpeed、Fabric等专业框架已内置管理此类数据迁移的机制,大家仍可通过以下方式自主实现该功能。

def offload_activation(tensor):# Move tensor to CPU to save GPU memoryreturn tensor.cpu()def process_batch(data):# Offload some activations explicitlyintermediate = model.layer1(data)intermediate = offload_activation(intermediate)intermediate = intermediate.cuda()  # Move back when neededoutput = model.layer2(intermediate)return output

10、使用更精简的优化器

并非所有优化器对内存的需求均等。以广泛使用的Adam优化器为例,其针对每个模型参数需额外维护两个状态变量(均值与方差),导致内存占用倍增。相比之下,采用无状态优化器(如SGD)可将参数总量减少近三分之二——这对于训练大语言模型(LLMs)及其他大规模模型具有显著意义。

尽管普通SGD优化器存在收敛性能较弱的缺陷,但通过引入余弦衰减学习率调整策略(Cosine Decay Learning Rate Scheduler)可有效补偿这一不足。简而言之:

# instead of this
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
# use this
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
num_steps = NUM_EPOCHS * len(train_loader)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)

通过这一调整,大家可以在显著改变峰值内存占有量的同时(具体取决于实际任务需求),仍能保持模型精度接近97%的水平。

11、高级策略

虽然上面列出的技术确实为我们奠定了坚实的基础,但我还想列出一些其他高级策略,我们可以考虑将 GPU 提升到极限:

  • 内存剖析和高速缓存管理

如果无法测量,就很难优化。PyTorch 提供了一些检查 GPU 内存使用情况的默认实用程序。使用方法如下:

import torch
# print a detailed report of current GPU memory usage and fragmentation
print(torch.cuda.memory_summary(device=None, abbreviated=False))
# free up cached memory that’s no longer needed by PyTorch
torch.cuda.empty_cache()
  • 使用TorchScript进行JIT编译

PyTorch 的即时(JIT)编译器使大家使用 TorchScript 将 Python 模型转换为优化的、可序列化的程序。通过优化内核启动和减少开销,这种转换可以提高内存和性能。您可以通过以下方式轻松访问它:

import torch
# Suppose `model` is an instance of your PyTorch network.
scripted_model = torch.jit.script(model)
# Now, you can run the scripted model just like before.
output = scripted_model(input_tensor)

尽管框架原生方法已能实现基础功能,但模型编译技术通常能带来更深层次的性能优化。

  • 自定义内核融合

编译的另一个主要好处是将多个操作融合到一个内核中。这有助于减少内存读/写,提高整体吞吐量。融合后的操作如下:

  • 使用torch.compile()进行动态内存分配

再次从编译中获益--使用 JIT 编译器可通过利用跟踪和图形优化技术的编译时优化来优化动态内存分配,从而进一步压缩内存并提高性能,尤其是在大型模型和Transformer架构中。

12、总结

随着 GPU 和云计算变得异常昂贵,只有充分利用现有资源才有意义。这有时可能意味着要在单个 GPU 工作站/笔记本电脑上对 LLM 或视觉Transformer进行训练/微调。上面列出的技术是研究人员/专业人士在算力紧张的情况下进行训练所使用的众多策略中的一部分。

参考资料:AI算法之道

http://www.xdnf.cn/news/278443.html

相关文章:

  • 【Java JUnit单元测试框架-60】深入理解JUnit:Java单元测试的艺术与实践
  • Spring AI 实战:第九章、Spring AI MCP之万站直通
  • HTML5实战指南:语义化标签与表单表格高级应用
  • AI日报 · 2025年5月04日|Hugging Face 启动 MCP 全球创新挑战赛
  • 《工业社会的诞生》章节
  • 相向双指针-16. 最接近的三数之和
  • 基于AWS Marketplace的快速解决方案:从选型到部署实战
  • OpenFAST 开源软件介绍
  • 大学之大:高丽大学2025.5.4
  • Java并发编程-多线程基础(三)
  • 在 Ubuntu 系统中,查看已安装程序的方法
  • Redis-----认识NoSQL
  • 网络开发基础(游戏)之 心跳机制
  • C++ 多态:原理、实现与应用
  • 科学养生,健康生活
  • Python容器与循环:数据处理的双剑合璧
  • 虚函数 vs 纯虚函数 vs 静态函数(C++)
  • 原型模式(Prototype Pattern)
  • drawDB:打造高效数据库设计流程
  • Go-Spring 全新版本 v1.1.0
  • 潮乎盲盒商城系统全开源多级分销推广海报奖品兑换试玩概率OSS云存储多端源码
  • 工业大模型:从设备诊断到工艺重构
  • 从入门到登峰-嵌入式Tracker定位算法全景之旅 Part 3 |混合定位实战:Wi-Fi RTT / LoRa / BLE RSSI AoA 多源融合
  • node.js为什么产生?
  • Qt基础知识记录(终篇)
  • 前端面试每日三题 - Day 24
  • SpringCloud教程 — 无废话从0到1逐步学习
  • 小刚说C语言刷题—1324扩建鱼塘问题
  • C++基础代码解释
  • dubbo 参数校验-ValidationFilter