📝 并行理论的黑称
  • Vertical Split: 竖切, PP
  • Horizontal Split: 横切, TP

消耗显存的模块分类

Part Category Included Items Description
1 Model states optimizer states, gradients, parameters For large models, the majority of the memory is occupied by model states.
2 Residual states activation, temporary buffers, unusable fragmented memory The remaining memory is consumed by residual states.

Model State 不足分析 + ZeRO-DP 性能

Parallelism Method Why Compute Efficiency Is High / Low Why Memory Efficiency Is High / Low Core Mechanism Main Problem
Data Parallelism (DP) High compute efficiency because each GPU runs the full model independently on its own mini-batch, so computation is large-grained and continuous, and communication is mainly limited to gradient synchronization. Low memory efficiency because each GPU stores a full copy of all model states, including parameters, gradients, and optimizer states. Replicate the full model on every GPU, split data across GPUs, then synchronize gradients. Heavy memory redundancy across GPUs.
Model Parallelism (MP) Low compute efficiency because one model computation is split across multiple GPUs, which introduces fine-grained coordination, frequent synchronization, and expensive communication on the critical path. High memory efficiency because each GPU stores only a partition of the model states instead of the full model. Partition the model itself across GPUs so each GPU holds only part of the parameters / states. Communication overhead, synchronization cost, and reduced scalability.
ZeRO-DP Tries to retain the high compute efficiency of DP by preserving DP-style computation granularity and communication pattern as much as possible. Tries to achieve the high memory efficiency of MP by partitioning model states across data-parallel processes instead of fully replicating them. Keep DP-style training flow, but partition optimizer states, gradients, and eventually parameters to remove redundancy. More complex state management and dynamic communication scheduling.

ZeRO_model_state_compression_results.png

Residual State 不足分析 + ZeRO-R 性能

优化对象 问题 / 瓶颈 ZeRO-R 的做法 目标
激活值(activations) 大模型中, checkpointing 虽然有帮助, 但仍然不足; 现有 MP 方法中还存在 activation replication 通过 activation partitioning 识别并消除激活值复制; 在合适的时候将 activations 卸载到 CPU 降低 activation memory 占用
临时缓冲区(temporary buffers) 临时缓冲区大小如果设置不当, 会影响内存与计算效率的平衡 为 temporary buffers 选择合适的大小 在内存占用和计算效率之间取得平衡
内存碎片(fragmented memory) 不同 tensor 生命周期不同, 导致训练中出现内存碎片; 即使总空闲内存足够, 也可能因缺少连续内存而分配失败 根据不同 tensor 的生命周期主动管理内存, 防止 memory fragmentation 减少碎片化, 避免内存分配失败

Mixed-Precision Training

在训练的时候,为了能有更加轻量级、快速、小内存的目标,工程界往往会引入压缩精度的方法,也就是用 FP16 来代替 float32 和 float64;
但是为了目标节点的计算精度保持,优化器可能会用高精度版本保存一定的关节参数, 这个过程就叫做混合精度训练

空间复杂度分析

定义 Ψ\Psi 是模型参数总数,比如一个 7B 模型就表示参数总量是 7.5×1097.5\times 10^9B, 一般系统的内存总占用需要的是 精度位数 * 参数总量,比如我们统一用 FP16 的话,得到的内存总占用就是 2Ψ2\Psi bytes

Adam Optimizer 显存占用

根据论文的说法,adam 优化器的显存占用主要是如下内容:

  • fp16 parameters: 用于 forward/backward 的低精度模型权重
  • fp16 gradients: 用于 step backward 传播后算出来的梯度
  • fp32 master weights: 同一个参数 (parameters) 的高精度主副本,拿来真正做参数更新的,低精度的更多只是用来算 forward 和 backward 的
  • fp32 momentum: 这是 adam 优化器的一阶矩,记录的是最近一段时间这个参数的梯度方向趋势
  • fp32 variance: 粗略理解就是梯度的平方的趋势,记录最近的梯度的尺度和波动强度

从上述列表可以看出,优化器的显存可以分为两类,一类是用来做运算传播的,另一类高精度是用来做模型参数更新的,如果从训练的角度来看,更新参数是最直观可以从训练运算过程中解耦的,也就是我们不影响数据流向,先考虑把模型参数这种重量级不动产先解耦出来去除冗余

ZeRO-1: PosP_{os} Optimizer State Partitioning

接上文逻辑,为了让每一组显卡不需要对所有高精度参数都做更新,这里就让局部参数可以更新,也就是将一个分布式训练系统分成 NdN_d 等分,每个 data-parallel rank 只负责 1Nd\frac{1}{N_d} 份参数区域的更新,最后做一组 all-scatter 就获得了全部的状态更新;

按照这个设计, 我们先假设这里 Optimizer 用到的额外字节数是 KΨK\Psi bytes (根据上面的 adam 案例就是 12Ψ\Psi bytes) 参数,那么根据 ZeRO-1 的设计我们最终的压缩得到的参数尺寸就是 4Ψ+KNdΨ4\Psi + \frac{K}{N_d} \Psi bytes

注意在这一步中,所有显卡仍然拿到完整的数据并且计算完整的梯度,然后按照经典的 DDP 协议通信来实现 all-reduce 得到全局统一梯度,只是在更新数据的时候有选择地更新自己 rank 负责的数据

ZeRO-2: PgP_g Gradient Partitioning

在这一步已经将参数层面切分开了,但是优化器中的梯度分区是和参数分区一一对应的,所以这里就会有一个优化角度: 将每个 rank 分区的梯度也轻量化,或者说每个分区只负责一部分的梯度计算,也只需要接受相关梯度分区的传递数据,下面用一个例子来辅助说明:

对于一个二分区系统,我们有 rank0 和 rank1 来对数据做计算: rank0 负责分区 A rank1 负责分区 B;

按照这个设计,我们这里的 gradients 的内存占用就从 2Ψ\Psi 变成 2NdΨ\frac{2}{N_d}\Psi bytes 了

这里每个 rank 只算本地的梯度了,算完之后对梯度做分区的 reduce

ZeRO-3: PpP_p Parameter Partitioning

接下来就是顺理成章地轻量化 parameter 部分了,也就是只需要计算、传递和自己相关的部分的梯度吧