Skip to content

图解分布式训练(八)—— ZeRO 学习

📝通俗解释:ZeRO(Zero Redundancy Optimizer)是一种用于训练超大模型的技术,它通过"分片"的方式,让多张显卡共同分担原本需要单独存储的模型数据,从而解决显存不够的问题。就像原来每个人都要背一整本书,现在大家分工,每人只背一部分,最后拼起来就能完成学习。


一、什么是 3D 并行?

3D 并行是一种组合多种并行策略的技术,可以让大型模型以非常高效的方式进行训练。

📝通俗解释:想象一下搬砖头,一个人搬很慢,3D并行就像同时用了"好几个人+不同搬运方法"——有人负责把砖头分成多份(数据并行),有人把砖头横向切开分工(张量并行),有人按楼层顺序传递(流水线并行),三种方法一起用,效率最高。


二、3D 并行策略有哪些?

  • DataParallel (DP):数据并行
  • TensorParallel (TP):张量并行(也称模型并行)
  • PipelineParallel (PP):流水线并行(也称管道并行)

2.1 DataParallel (DP)

介绍:假设有 N 张卡,每张卡都保存一个完整模型副本,每一次迭代(iteration/step)都将 batch 数据分割成 N 个等大小的 micro-batch,每张卡根据拿到的 micro-batch 数据独立计算梯度,然后调用 AllReduce 计算梯度均值,每张卡再独立进行参数更新。

举例说明

text
# 假设模型有三层:L0, L1, L2
# 每层有两个神经元   # 两张卡(每张卡都有完整模型)
GPU0:    L0 | L1 | L2   ---|----|---   a0 | b0 | c0   a1 | b1 | c1
GPU1:    L0 | L1 | L2   ---|----|---   a0 | b0 | c0   a1 | b1 | c1

📝通俗解释:DataParallel就像让两个人各自完整地做同一套试卷的不同题目,最后把答案综合起来。比如题目有10道,两人各做5道,做完后交换答案核对,这样每个人就都有了完整试卷的解题思路。


2.2 TensorParallel (TP)

介绍:每个张量都被分成多个块,因此不是让整个张量驻留在单个 GPU 上,而是张量的每个分片都驻留在其指定的 GPU 上。在处理过程中,每个分片在不同的 GPU 上分别并行处理,最终结果在步骤结束时同步。这也被称作横向并行

举例说明

text
# 假设模型有一层Linear层,包含权重矩阵W
# 两张卡,将权重矩阵按列切分
# W = [W_col0 | W_col1]
# GPU0: W_col0  GPU1: W_col1
# 输入X同时发送给两张卡,分别计算 X @ W_col0 和 X @ W_col1
# 最后结果相加:X @ W = (X @ W_col0) + (X @ W_col1)

📝通俗解释:TensorParallel就像一道数学大题,两个人分工做——一个人算前半部分,一个人算后半部分,最后把答案相加才是完整结果。原本需要一个人算整个题,现在两个人同时算,效率翻倍。


2.3 PipelineParallel (PP)

介绍:模型在多个 GPU 上垂直(层级)拆分,因此只有模型的一个或多个层放置在单个 GPU 上。每个 GPU 并行处理管道的不同阶段,并处理一小部分批处理。

举例说明

text
# 假设模型有 8 层
# 两张卡
| L0 | L1 | L2 | L3 |   | L4 | L5 | L6 | L7 |
====================   ====================
       GPU0                  GPU1

📝通俗解释:PipelineParallel就像工厂流水线——第一个人负责加工前4道工序,把半成品传给第二个人继续加工后4道工序。虽然有传递等待时间,但两个人可以同时工作,整体效率比一个人做完所有工序要高。


三、为什么需要 ZeRO?

虽然 DataParallel (DP) 因为简单易实现,所以目前应用相比于其他两种广泛,但是由于 DataParallel (DP) 需要每张卡都存储一个模型,导致 显存大小 成为 制约模型规模主要因素

既然每张卡都存储一个模型会增加模型训练过程中的显存占用,那么是否可以让每张卡训练 1/N 的模型参数,然后合并起来就是一个完整模型呢?这样,随着卡数的增加,每张卡用于模型训练的显存占用将降低,能够训练的模型也就越大。

如今训练大模型离不开各种分布式并行策略,ZeRO 系列技术就是一种显存优化的数据并行方案,旨在训练超大规模的语言模型。

📝通俗解释:这就相当于一个图书馆,原来每个分馆都要存放整套书籍(浪费空间),现在改成总馆存一套,各个分馆只借阅自己需要的部分(按需分配),这样每个分馆的书架压力就小多了。ZeRO就是让每张显卡只存储模型的一部分,而不是全部。


四、ZeRO 的核心思想是什么?

去除数据并行中的冗余参数,使每张卡只存储一部分模型状态,从而减少显存占用。

📝通俗解释:就像一个团队做项目,原来每个人都持有全部资料(冗余),现在改成每个人只负责自己那部分,需要时再找队友调用。这样既省了存储空间(显存),团队配合好的话效率也不受影响。


五、ZeRO 显存如何分配?

ZeRO 将模型训练阶段中每张卡的显存内容分为两类:

  • 模型状态:包括参数(Parameters)梯度(Gradients)优化器状态(Optimizer States),其中优化器状态占比约 75%
  • 剩余状态:除了模型状态之外的显存占用,包括激活值(Activations)、各种临时缓冲区以及无法使用的显存碎片

📝通俗解释:训练一个大模型就像开一家工厂,模型状态是核心机器设备(占大部分空间),激活值就像生产过程中的半成品(用完可以丢掉)。ZeRO优化的重点就是那些"核心机器"——尤其是其中的优化器状态,它就像机器的维护记录,占用空间最大。

举例说明

GPT-2 含有 1.5B 个参数,如果用 fp16 格式(混合精度),参数本身只需要 3GB 显存,但是模型状态(包括优化器等)实际上需要耗费 24GB!所以模型状态就成了头号显存杀手,它也是 ZeRO 的重点优化对象。


六、ZeRO 优化策略是怎么样?

针对模型状态的存储优化(去除冗余),ZeRO 使用的优化策略是分片,即每张卡只存 1/N 的模型状态量,这样系统内只维护一份完整模型状态。

6.1 介绍一下 ZeRO 优化策略有哪几种?

ZeRO 具有三个主要的优化阶段(ZeRO-1ZeRO-2ZeRO-3),它们对应于优化器状态(Optimizer States)、梯度(Gradients)和参数(Parameters)的分片,分别对应 Model States 不同程度的分割(Partition):

优化阶段优化内容显存减少倍数通信量变化
ZeRO-1分割 Optimizer States4倍与DP相同
ZeRO-2分割 Optimizer States + Gradients8倍与DP相同
ZeRO-3分割 Optimizer States + Gradients + Parameters与GPU数量成线性关系(如64卡减少64倍)增加约50%

📝通俗解释:

  • ZeRO-1:只把"优化器状态"分片,每张卡只管1/4,显存省4倍
  • ZeRO-2:再加上"梯度"也分片,显存省8倍
  • ZeRO-3:连"模型参数"也分片,卡越多省得越多(比如64张卡只占用1/64的显存) 通信量就是各卡之间传递数据的量,ZeRO-3因为需要频繁传递参数,通信量会稍微增加一些。

6.2 介绍一下 ZeRO-Offload 优化策略?

一张卡训不了大模型,根因是显存不足ZeRO-Offload 则将训练阶段的某些模型状态下放(offload)到内存以及 CPU 计算,即显存不足,内存来补。相比于昂贵的显存,内存廉价多了。

工作原理

  • GPU 部分:负责计算密集的前向传播(FWD)和反向传播(BWD)
  • CPU 部分:负责需要大量内存的优化器状态更新(Parameter Update)

流程

  1. 在 GPU 上面进行前向和后向计算,将梯度传给 CPU(可并行计算与通信)
  2. CPU 进行参数更新,再将更新后的参数传回 GPU

📝通俗解释:ZeRO-Offload就像一个"外包"策略——把显卡干不了的"体力活"(大量内存操作)外包给CPU和内存。显卡擅长快速计算(做菜),内存擅长存储(仓库),各干各擅长的,效率更高还省钱。


6.3 介绍一下 ZeRO-1 原理?

ZeRO-1:Optimizer States Partitioning (Pos)

  • 显存减少:4倍
  • 通信量:与数据并行相同

原理: Optimizer 在进行梯度更新时,会使用参数与 Optimizer States 计算新的参数。而在正向或反向传播中,Optimizer States 并不会参与其中的计算。因此,我们完全可以让每个进程只持有一小段 Optimizer States,利用这一小段 Optimizer States 更新完与之对应的一小段参数后,再把各个小段拼起来合为完整的模型参数。

实现过程

  1. 假设有 N_d 个并行的进程,ZeRO-1 将完整优化器的状态等分成 N_d 份并储存在各个进程中
  2. 当 Backward 完成之后,每个进程的 Optimizer 对自己储存的 Optimizer States(包括 Momentum、Variance 与 FP32 Master Parameters)进行计算与更新
  3. 更新过后的 Partitioned FP32 Master Parameters 会通过 All-gather 传回到各个进程中,完成一次完整的参数更新

效果:通过 ZeRO-1 对 Optimizer States 的分片化储存,7.5B 参数量的模型内存占用将由原始数据并行下的 120GB 缩减到 31.4GB。

📝通俗解释:ZeRO-1就像一个工作组,每个人只保管1/N的"工作笔记"(优化器状态),用自己这部分笔记更新对应的1/N的参数,最后大家把更新后的参数拼在一起,就得到了完整的更新结果。笔记不用每次都复制多份,自然省显存。


6.4 介绍一下 ZeRO-2 原理?

ZeRO-2:Optimizer States and Gradient Partitioning (Pos+g)

  • 显存减少:8倍
  • 通信量:与数据并行相同

原理: ZeRO-1 将 Optimizer States 分小段储存在了多个进程中,所以在计算时,这一小段的 Optimizer States 也只需要得到进程所需的对应一小段 Gradient 就可以。遵循这种原理,和 Optimizer States 一样,ZeRO-2 也将 Gradient 进行了切片:

实现过程

  1. 在一个 Layer 的 Gradient 都被计算出来后:Gradient 通过 AllReduce 进行聚合(类似于 DDP)
  2. 聚合后的梯度只会被某一个进程用来更新参数,因此其它进程上的这段 Gradient 不再被需要,可以立马释放掉(按需保留)

这样就在 ZeRO-1 的基础上实现了对 Gradient 的切分。

效果:通过 ZeRO-2 对 Gradient 和 Optimizer States 的分片化储存,7.5B 参数量的模型内存占用将由 ZeRO-1 中的 31.4GB 进一步下降到 16.6GB。

📝通俗解释:ZeRO-2在ZeRO-1基础上再优化——不仅笔记(优化器状态)分开保管,连"草稿纸"(梯度)也分开。用完的草稿纸马上扔掉,不需要每个人都备份,自然又省了一倍空间。


6.5 介绍一下 ZeRO-3 原理?

ZeRO-3:Optimizer States, Gradient and Parameter Partitioning (Pos+g+p)

  • 显存减少:与数据并行度 N_d 成线性关系
  • 通信量:增加约 50%

原理: 当 Optimizer States、Gradient 都被分布式切割分片储存和更新之后,剩下的就是 Model Parameter 了。ZeRO-3 通过对 Optimizer States、Gradient 和 Model Parameter 三方面的分割,从而使所有进程共同协作,只储存一份完整 Model States。其核心思路就是精细化通讯,按照计算需求做到参数的收集和释放。

核心机制

  • 参数收集(gather):在需要使用某层参数时,才从其他进程收集该层参数
  • 参数释放(partition):使用完毕后立即释放非本进程负责的参数量

📝通俗解释:ZeRO-3更彻底——不仅笔记和草稿纸分开,连"教材"(模型参数)本身也切开保管。需要用哪部分内容时再找对应的人要,用完马上还回去。这样每个人手里只有1/N的教材,显存占用最少。当然,这样跑来跑去需要更多通信,但为了能训练超大模型,还是值得的。


七、ZeRO Offload 后的计算流程是怎么样?

  1. GPU 阶段(FWD & BWD)

    • 在 GPU 上进行前向传播和反向传播计算
    • 为了提高效率,可以将计算和通信并行起来:GPU 在反向传播阶段,可以待梯度值填满 bucket 后,一边计算新的梯度一边将 bucket 传输给 CPU
    • 当反向传播结束时,CPU 基本上已经有最新的梯度值了
  2. CPU 阶段(Parameter Update)

    • CPU 进行参数更新
    • 将更新后的参数传回 GPU

📝通俗解释:这个流程就像一个"交接班"——显卡A(GPU)负责快速计算,算完后把结果"快递"给CPU,CPU慢慢处理更新,更新完再"快递"给显卡。快递和计算可以同时进行(流水线),不浪费时间。


八、DeepSpeed ZeRO3 内部实现初探篇

8.1 deepspeed 程序内部到底做了什么?

这个问题可以总结为两个方面:

  1. deepspeed ... <user_script>.py ... 命令的本质是什么?
  2. trainer.train(...) 里面发生了什么?

8.2 介绍一下 deepspeed 命令的本质?

deepspeed 命令的本质是一个 Python 脚本

bash
# 查找 deepspeed 命令路径
$ which deepspeed
/home/dingqiang/miniconda3/bin/deepspeed

# 查看内容
$ cat /home/dingqiang/miniconda3/bin/deepspeed
#!/home/dingqiang/miniconda3/bin/python

from deepspeed.launcher.runner import main

if __name__ == '__main__':
    main()

执行流程

  1. runner.py 会启动一个子进程执行 launcher/launch.py
  2. 然后 launch.py 会启动多个子进程执行用户脚本 <user_script>.py,子进程数量等于 GPU 数目
  3. 每个子进程都会被提供相应的 RANKLOCAL_RANK 环境变量,用来指定子进程用哪个 GPU

📝通俗解释:deepspeed命令就是一个"总调度员",它负责启动多个"小兵"(子进程)来同时执行你的训练脚本。每个小兵分配一个显卡编号(RANK),让大家分工合作。


8.3 trainer.train(...) 里面发生了什么?

Trainer.train(...) 主要干了三件事(如果用户提供了 resume_from_checkpoint 参数的话,还要从 deepspeed 专用格式的 checkpoint 中加载模型参数和优化器状态,这里暂不考虑):

  1. 初始化 deepspeed 的分布式环境:由 torch.distributed.init_process_group() 实现
  2. 用 DeepSpeedEngine(以下简称 Engine)封装模型:这是实现 ZeRO3 的核心
  3. 实现训练 loop 的代码:与普通的 PyTorch 代码一样

Engine 初始化的关键操作

  1. 注册钩子函数:递归地给模型的每个子模块注册 4 个钩子函数(pre/post forward/backward)
  2. 减少内存碎片:把切分好的参数塞进一个扁平的 buffer
  3. 注册反向传播的额外钩子函数:用于归约和切分梯度

钩子函数的作用

  • 两个 forward 相关的钩子:负责在每个子模块 forward 前后聚合重新切分参数
  • 两个 backward 相关的钩子:类似地做参数的聚合和划分
  • 反向传播的额外钩子:用于归约和切分梯度

📝通俗解释:Engine就像一个"智能包装盒",把原始PyTorch模型包起来。这个包装盒会在模型运行的前前后后自动"动手脚"——需要参数时自动召集(gather),用完后自动分给各人(partition),这些都通过"钩子"自动完成,用户完全不用手动处理。


8.4 from_pretrained() & 参数切分与聚合?

from_pretrained() 是完成预训练权重加载的函数。在单进程的 Python 程序中,它可以由 torch.load() 以及 load_state_dict() 替代。但是在 deepspeed 环境下,它有一些新的行为:

  1. 给模型参数添加新属性和方法,使它们成为"zero 参数":

    • ds_tensor:用于存放被切分的参数
    • ds_status:用于表示参数是否被切分
    • partition():用于切分参数
    • all_gather():用于聚合参数
  2. "分布式地"加载模型权重

    • 由主进程递归地加载子模块
    • 每加载完一个子模块就广播给所有进程,然后所有进程都做切分

参数切分的实现

  1. 将参数展平到一维
  2. 每个进程根据自己的 rank 找到自己存的参数片段
  3. 参数 param 切完的值保存在 param.ds_tensor,释放 param.data

参数聚合的实现

  • 调用 torch.distributed.all_gather_into_tensor()
  • deepspeed 加入了一个技巧:把一个模块的参数都装进一个连续的 buffer 来做聚合,这样可以提高 GPU 之间通信的吞吐量

📝通俗解释:加载预训练模型时,deepspeed会"偷偷"给每个参数加上"分家"的能力。加载过程是:主进程先加载,然后广播给大家,每台机器拿到后把自己那份切下来保存,其余的释放掉。聚合时则像拼图一样,把所有人的碎片拼成完整图案,为了更快,它会把碎片先拼成几大块再传递。


九、DeepSpeed ZeRO3 的 backward & step 的实现

9.1 介绍一下 backward 实现机制?

DeepSpeed 通过在反向传播过程中注册额外的钩子函数来实现梯度的分片:

  1. 梯度计算:每个 GPU 独立计算本地 mini-batch 的梯度
  2. 梯度聚合:通过 AllReduce 聚合所有 GPU 的梯度
  3. 梯度分片:聚合后的梯度只保留当前进程负责的部分,释放其他部分
  4. 分片通知:触发 ZeRO 优化器的相应钩子,进行梯度状态的更新

📝通俗解释:backward时,系统先让每个人算出自己的"解题思路"(梯度),然后大家互相比较综合,得到"标准答案"(聚合梯度)。但每个人都只需要记住和自己相关的1/N部分,其他的马上忘记(释放),这样就能省显存。


9.2 介绍一下 step 实现机制?

ZeRO-3 的参数更新(step)过程涉及参数的动态收集和释放:

  1. 参数收集(All-Gather):在更新参数前,需要先收集所有分片的参数
  2. 参数更新:使用优化器状态和梯度更新参数
  3. 参数释放(Partition):更新完成后,再次切分参数,只保留当前进程负责的部分
  4. 广播:将更新后的参数广播给所有进程

📝通俗解释:step就像"汇总考试答案"——每次更新参数前,大家先把各自保管的1/N拼成完整试卷(收集),老师批改(更新),改完后再各自分走属于自己的1/N(释放),等待下次考试。这样既保证了大家用到的都是最新参数,又让每个人只操一份心。


参考资料

基于 MIT 许可发布