ZeRO与3D并行之间的关系
我们来梳理一下 ZeRO (Stages 1, 2, 3) 和 3D 并行(数据并行 DP、张量并行 TP、流水线并行 PP)之间的关系。
核心关系:
- 3D 并行 (DP, TP, PP) 是用于分布式训练的基础策略,目的是将模型训练的计算负载和模型状态(参数、梯度、优化器状态)分散到多个 GPU 上,以应对模型过大或加速训练的需求。它们定义了如何在不同维度上切分任务和数据。
- ZeRO (Zero Redundancy Optimizer) 是一种显存优化技术,它专门针对数据并行 (DP) 带来的显存冗余问题进行优化。它通过对模型状态进行分区(Sharding) 来显著减少每个参与数据并行的 GPU 所需的显存。
可以这样理解:
- 3D 并行是“骨架”: 它定义了你的训练任务如何在多个 GPU 上分布。你可以选择只用 DP,或者 DP+TP,或者 DP+PP,或者三者结合(DP+TP+PP)。
- 数据并行 (DP): 每个 GPU 持有完整的模型副本,处理不同批次的数据。这是 ZeRO 主要作用和优化的对象。
- 张量并行 (TP): 将单个大算子(如矩阵乘法)切分到多个 GPU 上,每个 GPU 只处理一部分计算和存储一小块权重。通常用于节点内高速互联的 GPU。
- 流水线并行 (PP): 将模型的不同层分配到不同的 GPU(称为 Stage),数据像流水线一样流过。用于处理深度过大的模型。
- ZeRO 是对 DP 的“增强插件”或“优化实现”: 传统的 DP 效率不高,因为每个 GPU 都需要存储完整的模型参数、梯度和优化器状态,导致显存占用巨大且冗余。ZeRO 就是为了解决这个问题而生的。
- ZeRO 不是一种新的并行“维度”,而是对数据并行维度的一种实现方式。
- 当你决定使用数据并行时,你可以选择:
- 传统 DP: 每个 GPU 保存所有参数、所有梯度、所有优化器状态。
- 使用 ZeRO 优化的 DP:
- ZeRO Stage 1: 对优化器状态进行分区。每个 DP rank 只保存一部分优化器状态。参数和梯度仍然是完整的副本。
- ZeRO Stage 2: 在 Stage 1 基础上,对梯度也进行分区。每个 DP rank 只保存自己计算出的那部分参数对应的梯度(在 AllReduce 后)。参数仍然是完整的副本。
- ZeRO Stage 3: 在 Stage 2 基础上,对模型参数本身也进行分区。每个 DP rank 只持有完整模型参数的一部分。在计算需要时,通过通信动态获取所需的其他部分的参数。这是最节省显存的阶段。
它们如何协同工作:
在训练一个非常大的模型时(比如你提到的 72B 模型),通常会组合使用 3D 并行和 ZeRO:
- 全局视角: 你可能会将整个集群的 GPU 划分成几个流水线阶段 (PP)。
- 阶段内视角: 在每个流水线阶段内部,可能由多个 GPU 组成。这组 GPU 可能首先进行张量并行 (TP) 来处理单个大层(例如,8 个 A100 做 8 路 TP)。
- 数据并行视角: 这个经过 TP 组合的“逻辑单元”(代表了模型的一部分层)会被复制多份,构成数据并行 (DP)。而这个数据并行,几乎肯定会采用 ZeRO(尤其是 ZeRO Stage 3)来实现,而不是传统的、显存效率低下的 DP。
举例说明一个可能的组合配置:
假设你有 64 个 GPU。
-
流水线并行 (PP): 将模型分为 8 个 Stage (PP=8)。每个 Stage 分配 8 个 GPU。
-
张量并行 (TP): 在每个 Stage 内部,这 8 个 GPU 进行 8 路张量并行 (TP=8)。它们共同负责执行分配给这个 Stage 的那部分模型层的计算。
-
数据并行 (DP) & ZeRO: 在这个例子中,因为 PP=8, TP=8,总 GPU 数是 64,所以数据并行度 DP = 64 / (PP * TP) = 64 / (8 * 8) = 1。这意味着没有传统意义上的数据并行复制。
换个例子: 假设你有 128 个 GPU。
- PP = 8
- TP = 8
- DP = 128 / (8 * 8) = 2。这意味着有 2 个数据并行的副本。这两个副本会运行 ZeRO Stage 3。每个副本(由 8x8=64 个 GPU 构成,分布在 8 个 PP stage 上,每个 stage 内有 8 个 TP GPU)会处理全局批次数据的一半。在进行参数更新时,这两个 DP rank 会根据 ZeRO-3 的规则,只持有和更新部分分区的参数、梯度和优化器状态,并通过通信协调。
总结:
- DP, TP, PP 是分布式训练的策略,用于划分计算和模型。
- ZeRO 是针对 DP 策略的显存优化技术,通过分区状态来减少冗余。
- 在实际大模型训练中,通常是结合使用这些技术。ZeRO 使得在有限的显存下运行大规模数据并行成为可能,从而可以有效地与其他并行策略(TP, PP)组合,以训练巨大的模型。你不是在 ZeRO 和 3D 并行之间做选择,而是在设计 3D 并行策略时,决定数据并行部分如何实现(传统方式还是 ZeRO 优化方式)。对于大模型,几乎总是选择 ZeRO。