Skip to content

PyTorch 分布式计算坑/bug梳理篇

来源:AiGC面试宝典 作者:宁静致远 日期:2024年01月27日


动机

PyTorch用的人越来越多,大的模型都需要用GPU或者多张GPU甚至多节点多卡进行分布式计算,但是坑也很多。本文主要介绍读者在进行PyTorch分布式计算时所遇到的坑/bug的梳理及填坑记录。

📝 通俗解释:就像一个人搬不动重物,需要多人协作一样。训练大模型就像搬重物,一张GPU搬不动,需要多张GPU一起搬。但在多人协作的过程中,会出现各种配合问题,这篇文章就是记录这些问题的。


一、使用DistributedDataParallel(分布式并行)时,显存分布不均衡问题

问题描述

如果用DistributedDataParallel(分布式并行)的时候,每个进程单独跑在一个GPU上,多个卡的显存占用应该是均匀的。比如像这样:

Tue Oct 2223:34:40 2019
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 430.40      Driver Version: 430.40      CUDA Version: 10.1       |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  GeForce RTX 208...  Off  | 00000000:04:00.0 Off |                  N/A |
| 56%   66C    P2   225W / 250W |   7171MiB / 11019MiB |     95%      Default |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 208...  Off  | 00000000:05:00.0 Off |                  N/A |
| 59%   68C    P2   234W / 250W |   7167MiB / 11019MiB |     94%      Default |
+-------------------------------+----------------------+----------------------+
|   2  GeForce RTX 208...  Off  | 00000000:81:00.0 Off |                  N/A |
| 45%   59C    P2   240W / 250W |   7153MiB / 11019MiB |     94%      Default |
+-------------------------------+----------------------+----------------------+
|   3  GeForce RTX 208...  Off  | 00000000:85:00.0 Off |                  N/A |
| 62%   70C    P2   233W / 250W |   7161MiB / 11019MiB |     94%      Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0     30917      C   /usr/bin/python                            7161MiB |
|    1     30918      C   /usr/bin/python                            7157MiB |
|    2     30919      C   /usr/bin/python                            7143MiB |
|    3     30920      C   /usr/bin/python                            7151MiB |
+-----------------------------------------------------------------------------+

注:在Distributed模式下,相当于你的代码分别在多个GPU上独立运行,代码都是设备无关的。比如你写t = torch.zeros(100, 100).cuda(),在4个进程上运行的程序会分别在4个GPU上初始化t。所以显存的占用会是均匀的。

📝 通俗解释:想象4个人各自负责搬运一个仓库的货物,理论上每个人应该搬同样多的货物(显存均匀分布)。但有时候会出现一个人(GPU 0)帮其他人也搬了一部分,导致自己累得不行(显存占用高),其他人却比较轻松。

然而,有时会发现另外几个进程会在0卡上占一部分显存,导致0卡显存出现瓶颈,可能会导致cuda-out-of-memory错误。比如这样:

GPU状态表:

GPUNameMemory-UsageGPU-Util
0GeForce RTX 208...10846MiB / 11019MiB92%
1GeForce RTX 208...7169MiB / 11019MiB94%
2GeForce RTX 208...7157MiB / 11019MiB92%
3GeForce RTX 208...7159MiB / 11019MiB93%

进程列表:

GPUPIDProcess nameGPU Memory Usage
031570/usr/bin/python7235MiB
031571/usr/bin/python1199MiB
031572/usr/bin/python1199MiB
031573/usr/bin/python1199MiB
131571/usr/bin/python7159MiB
231572/usr/bin/python7147MiB
331573/usr/bin/python7149MiB

可以看到,GPU 0上跑了4个Python进程,而其他GPU各只跑了1个进程。

问题定位

该问题主要由以下代码导致:

python
checkpoint = torch.load("checkpoint.pth")
model.load_state_dict(checkpoint["state_dict"])

注:上述代码运行后,程序load一个预训练模型的时候,torch.load()会默认把load进来的数据放到0卡上,这样4个进程全部会在0卡占用一部分显存。

📝 通俗解释:就像你让4个人去仓库取货(加载预训练模型),但仓库只认第一个来的人(默认使用GPU 0),所以其他3个人取的东西也被记到了第一个人的账上,导致第一个人需要搬运的货物特别多(显存占用高)。

解决方法

把load进来的数据map到CPU上:

python
checkpoint = torch.load("checkpoint.pth", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint["state_dict"])

📝 通俗解释:相当于让所有人先从公共仓库(CPU内存)取货,然后再各自从公共仓库搬到自己的仓库(对应的GPU),这样就不会出现一个人帮所有人搬货的情况了。


二、使用PyTorch实现同步梯度更新时,自研数据接口导致第一个epoch结尾处程序卡死问题

如果使用PyTorch实现同步梯度更新,数据接口是自己编写的话,一定要注意保证每张卡分配的batch数量是一样的。因为如果某张卡少了一个batch的话,其他卡就会等待,从而程序卡在torch.all_reduce()上。最后的情况就会出现在第一个epoch结尾处程序卡住,而且没有报错信息。

📝 通俗解释:想象4个人搬货物,每个人需要搬10趟(10个batch)。如果其中一个人只搬了9趟就停了,其他3个人搬完自己的10趟后,发现还有货物没人搬,但又不知道那个人去哪了,就会一直等着他,造成所有人都卡住不动了。torch.all_reduce()就是那个让大家同步的工具,但它不知道有人提前跑了。

📝 通俗解释torch.all_reduce()是分布式计算中的一个同步操作,类似于所有人必须到齐才能继续下一步的"点名"机制。如果有人没到(某个GPU的batch还没处理完),其他人就会一直等下去。


三、微调大模型时,单机2卡正常但4卡及以上卡住问题

在微调大模型的时候,单机2卡的时候正常训练,但是采用4卡及以上,就会卡住,卡在读完数据和开始训练之间。

排查步骤

  1. 确认GPU通信正常:先确认几张卡都能正常使用和通信
  2. 检查batchsize分配:看看是不是batchsize分配之类的问题导致无限等待某一张卡了
  3. 最小化测试:只留4条数据,每张卡只跑一条数据试试看

📝 通俗解释

  • 步骤1:就像检查4个人之间能不能正常喊话(通信)
  • 步骤2:看看是不是分配任务的时候不均匀,有人太多有人太少
  • 步骤3:简化问题,就像让每个人只搬一件货,看看能不能正常完成

常见原因

  • 数据并行问题:不同GPU处理的数据量不一致
  • 同步等待:某个GPU处理速度慢,导致其他GPU等待
  • 模型复制问题:模型没有正确复制到所有GPU上

📝 总结:PyTorch分布式计算主要有三大坑:

  1. 显存不均衡 → 加载模型时指定CPU
  2. 自研数据接口导致卡死 → 保证每张卡batch数相同
  3. 多卡训练卡住 → 排查通信、batch分配、模型复制问题

参考来源:知识星球 - AiGC面试宝典

基于 MIT 许可发布