PyTorch 分布式计算坑/bug梳理篇
来源:AiGC面试宝典 作者:宁静致远 日期:2024年01月27日
动机
PyTorch用的人越来越多,大的模型都需要用GPU或者多张GPU甚至多节点多卡进行分布式计算,但是坑也很多。本文主要介绍读者在进行PyTorch分布式计算时所遇到的坑/bug的梳理及填坑记录。
📝 通俗解释:就像一个人搬不动重物,需要多人协作一样。训练大模型就像搬重物,一张GPU搬不动,需要多张GPU一起搬。但在多人协作的过程中,会出现各种配合问题,这篇文章就是记录这些问题的。
一、使用DistributedDataParallel(分布式并行)时,显存分布不均衡问题
问题描述
如果用DistributedDataParallel(分布式并行)的时候,每个进程单独跑在一个GPU上,多个卡的显存占用应该是均匀的。比如像这样:
Tue Oct 2223:34:40 2019
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 430.40 Driver Version: 430.40 CUDA Version: 10.1 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
|===============================+======================+======================|
| 0 GeForce RTX 208... Off | 00000000:04:00.0 Off | N/A |
| 56% 66C P2 225W / 250W | 7171MiB / 11019MiB | 95% Default |
+-------------------------------+----------------------+----------------------+
| 1 GeForce RTX 208... Off | 00000000:05:00.0 Off | N/A |
| 59% 68C P2 234W / 250W | 7167MiB / 11019MiB | 94% Default |
+-------------------------------+----------------------+----------------------+
| 2 GeForce RTX 208... Off | 00000000:81:00.0 Off | N/A |
| 45% 59C P2 240W / 250W | 7153MiB / 11019MiB | 94% Default |
+-------------------------------+----------------------+----------------------+
| 3 GeForce RTX 208... Off | 00000000:85:00.0 Off | N/A |
| 62% 70C P2 233W / 250W | 7161MiB / 11019MiB | 94% Default |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: GPU Memory |
| GPU PID Type Process name Usage |
|=============================================================================|
| 0 30917 C /usr/bin/python 7161MiB |
| 1 30918 C /usr/bin/python 7157MiB |
| 2 30919 C /usr/bin/python 7143MiB |
| 3 30920 C /usr/bin/python 7151MiB |
+-----------------------------------------------------------------------------+注:在Distributed模式下,相当于你的代码分别在多个GPU上独立运行,代码都是设备无关的。比如你写
t = torch.zeros(100, 100).cuda(),在4个进程上运行的程序会分别在4个GPU上初始化t。所以显存的占用会是均匀的。
📝 通俗解释:想象4个人各自负责搬运一个仓库的货物,理论上每个人应该搬同样多的货物(显存均匀分布)。但有时候会出现一个人(GPU 0)帮其他人也搬了一部分,导致自己累得不行(显存占用高),其他人却比较轻松。
然而,有时会发现另外几个进程会在0卡上占一部分显存,导致0卡显存出现瓶颈,可能会导致cuda-out-of-memory错误。比如这样:
GPU状态表:
| GPU | Name | Memory-Usage | GPU-Util |
|---|---|---|---|
| 0 | GeForce RTX 208... | 10846MiB / 11019MiB | 92% |
| 1 | GeForce RTX 208... | 7169MiB / 11019MiB | 94% |
| 2 | GeForce RTX 208... | 7157MiB / 11019MiB | 92% |
| 3 | GeForce RTX 208... | 7159MiB / 11019MiB | 93% |
进程列表:
| GPU | PID | Process name | GPU Memory Usage |
|---|---|---|---|
| 0 | 31570 | /usr/bin/python | 7235MiB |
| 0 | 31571 | /usr/bin/python | 1199MiB |
| 0 | 31572 | /usr/bin/python | 1199MiB |
| 0 | 31573 | /usr/bin/python | 1199MiB |
| 1 | 31571 | /usr/bin/python | 7159MiB |
| 2 | 31572 | /usr/bin/python | 7147MiB |
| 3 | 31573 | /usr/bin/python | 7149MiB |
可以看到,GPU 0上跑了4个Python进程,而其他GPU各只跑了1个进程。
问题定位
该问题主要由以下代码导致:
checkpoint = torch.load("checkpoint.pth")
model.load_state_dict(checkpoint["state_dict"])注:上述代码运行后,程序load一个预训练模型的时候,
torch.load()会默认把load进来的数据放到0卡上,这样4个进程全部会在0卡占用一部分显存。
📝 通俗解释:就像你让4个人去仓库取货(加载预训练模型),但仓库只认第一个来的人(默认使用GPU 0),所以其他3个人取的东西也被记到了第一个人的账上,导致第一个人需要搬运的货物特别多(显存占用高)。
解决方法
把load进来的数据map到CPU上:
checkpoint = torch.load("checkpoint.pth", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint["state_dict"])📝 通俗解释:相当于让所有人先从公共仓库(CPU内存)取货,然后再各自从公共仓库搬到自己的仓库(对应的GPU),这样就不会出现一个人帮所有人搬货的情况了。
二、使用PyTorch实现同步梯度更新时,自研数据接口导致第一个epoch结尾处程序卡死问题
如果使用PyTorch实现同步梯度更新,数据接口是自己编写的话,一定要注意保证每张卡分配的batch数量是一样的。因为如果某张卡少了一个batch的话,其他卡就会等待,从而程序卡在torch.all_reduce()上。最后的情况就会出现在第一个epoch结尾处程序卡住,而且没有报错信息。
📝 通俗解释:想象4个人搬货物,每个人需要搬10趟(10个batch)。如果其中一个人只搬了9趟就停了,其他3个人搬完自己的10趟后,发现还有货物没人搬,但又不知道那个人去哪了,就会一直等着他,造成所有人都卡住不动了。
torch.all_reduce()就是那个让大家同步的工具,但它不知道有人提前跑了。
📝 通俗解释:
torch.all_reduce()是分布式计算中的一个同步操作,类似于所有人必须到齐才能继续下一步的"点名"机制。如果有人没到(某个GPU的batch还没处理完),其他人就会一直等下去。
三、微调大模型时,单机2卡正常但4卡及以上卡住问题
在微调大模型的时候,单机2卡的时候正常训练,但是采用4卡及以上,就会卡住,卡在读完数据和开始训练之间。
排查步骤
- 确认GPU通信正常:先确认几张卡都能正常使用和通信
- 检查batchsize分配:看看是不是batchsize分配之类的问题导致无限等待某一张卡了
- 最小化测试:只留4条数据,每张卡只跑一条数据试试看
📝 通俗解释:
- 步骤1:就像检查4个人之间能不能正常喊话(通信)
- 步骤2:看看是不是分配任务的时候不均匀,有人太多有人太少
- 步骤3:简化问题,就像让每个人只搬一件货,看看能不能正常完成
常见原因
- 数据并行问题:不同GPU处理的数据量不一致
- 同步等待:某个GPU处理速度慢,导致其他GPU等待
- 模型复制问题:模型没有正确复制到所有GPU上
📝 总结:PyTorch分布式计算主要有三大坑:
- 显存不均衡 → 加载模型时指定CPU
- 自研数据接口导致卡死 → 保证每张卡batch数相同
- 多卡训练卡住 → 排查通信、batch分配、模型复制问题
参考来源:知识星球 - AiGC面试宝典