图解分布式训练(四)—— torch.multiprocessing 详细解析
来源:AiGC面试宝典
作者:宁静致远
日期:2023年09月29日 11:27
一、torch.multiprocessing 函数介绍一下?
torch.multiprocessing 是 Python 标准库 multiprocessing 模块的封装。它注册了自定义的 reducer(归约器),使用共享内存为不同进程中的相同数据提供视图共享。一旦张量/存储被移动到 shared_memory(参见 share_memory()),就可以将其发送到其他进程而无需复制数据。
📝通俗解释:想象一下,你有一本珍贵的书(张量),普通情况下你要把书借给朋友,需要整本复印一份。但有了 torch.multiprocessing,就像是有了一本"魔法书",你和朋友可以同时看同一本书,内容完全同步,而且不需要额外复印。这就是"共享内存"的原理。
这个 API 与原始模块 100% 兼容,只需将 import multiprocessing 改为 import torch.multiprocessing,就可以将所有张量通过队列发送或通过其他共享机制转移到共享内存中。
📝通俗解释:这就相当于你原本用普通手机打电话(multiprocessing),现在换成了同款式的智能手机(torch.multiprocessing),操作方式一模一样,但功能更强大了——可以直接使用共享内存这个"黑科技"。
由于 API 的相似性,我们没有重复记录这个软件包的大部分内容,建议您参考 Python multiprocessing 原始模块的文档。
📝通俗解释:就像你学会了骑自行车,学电动车就很简单,因为操作逻辑差不多。torch.multiprocessing 和 Python 原生的 multiprocessing 用法几乎一样,所以官方文档主要让你去看 Python 原版的。
警告:如果主进程突然退出(例如因为接收到的信号),Python multiprocessing 有时无法清理其子进程。这是一个已知的限制,所以如果您在中断解释器后看到任何资源泄漏,这可能意味着刚刚发生在您身上。
二、torch.multiprocessing 函数如何使用?
torch.multiprocessing.get_all_sharing_strategies()返回一组当前系统支持的共享策略。
📝通俗解释:这就像去餐厅吃饭,服务员告诉你今天有哪些套餐可选。get_all_sharing_strategies() 就是告诉你系统支持哪些"套餐"(共享策略)。
torch.multiprocessing.get_sharing_strategy()返回当前共享 CPU 张量的策略。
📝通俗解释:相当于查询"当前餐厅给你上了哪个套餐"。
torch.multiprocessing.set_sharing_strategy(new_strategy)设置共享 CPU 张量的策略。
参数:
new_strategy(str) — 所选策略的名称,应当是上面get_all_sharing_strategies()中系统支持的共享策略之一。
📝通俗解释:相当于告诉服务员"我要换套餐,改成 XXX 套餐"。不过一般用默认的就行,不需要手动切换。
三、介绍一下共享 CUDA 张量?
注意:仅支持 Python 3 中使用 spawn 或 forkserver 启动方法时才支持在进程之间共享 CUDA 张量。Python 2 中的 multiprocessing 只能使用 fork 创建子进程,并且不支持 CUDA 运行时。
📝通俗解释:这就好像你要把一个高科技设备(CUDA 张量)借给朋友,只有用正确的方法(Python 3 + spawn/forkserver)才能成功。如果用老办法(Python 2 + fork),设备可能会损坏。
警告:CUDA API 要求导出到其他进程的分配一直保持有效,只要它们被使用。您应该小心,确保您共享的 CUDA 张量不要超出作用域,只要有必要。这不应该是共享模型参数的问题,但传递其他类型的数据应该小心。请注意,此限制不适用于共享 CPU 内存。
📝通俗解释:想象你借给朋友一个临时储物柜(CUDA 张量),在你朋友还在用的时候,你不能把柜子要回来,否则朋友的东西会丢失。CPU 内存就没这个问题,你可以随时要回来。所以共享模型参数通常没问题,但共享其他数据时要小心。
四、介绍一下共享策略?
本节简要概述了不同的共享策略如何工作。请注意,它只适用于 CPU 张量 —— CUDA 张量将始终使用 CUDA API,因为它们是唯一的共享方式。
📝通俗解释:CPU 张量有多种"共享方式"可选,就像寄快递可以用顺丰、圆通、邮政等。但 CUDA 张量比较特殊,只能用"顺丰"(CUDA API),没得选。
4.1 文件描述符(file_descriptor)
注意:这是默认策略(除了不支持的 macOS 和 OS X)。
此策略将使用文件描述符作为共享内存句柄。当内存被移动到共享内存时,从 shm_open() 获取的文件描述符会被缓存,当它被发送到其他进程时,文件描述符也将被传送(例如通过 UNIX 套接字)。接收方也会缓存文件描述符并且使用 mmap 映射它,以获得对共享数据的视图。
📝通俗解释:想象你在银行开了个保险箱(共享内存),银行给你一张磁卡(文件描述符)。你可以把这张磁卡的"副本"通过快递(UNIX 套接字)寄给朋友,朋友用这张磁卡也能打开同一个保险箱。为了方便,你们都会把磁卡信息保存在自己家里(缓存),下次直接用。
请注意,如果要共享很多张量,则此策略将保留大量文件描述符,需要较长时间才能打开。如果您的系统对打开的文件描述符数量有限制,并且无法提高,您应该使用 file_system 策略。
📝通俗解释:相当于如果你的朋友要借很多很多保险箱,你得给他准备很多张磁卡。磁卡太多管理起来很麻烦,而且银行可能限制每个人最多开多少张卡。这时候就要换一种方式了。
4.2 文件系统(file_system)
该策略将使用给定的文件名 shm_open() 来标识共享内存区域。这具有不需要缓存文件描述符的优点,但同时容易发生共享内存泄漏。该文件创建后不能被删除,因为其他进程需要访问它以打开其视图。如果进程崩溃并且不调用存储析构函数,则文件将保留在系统中。这是非常严重的问题,因为它们在系统重新启动之前一直占用内存,或者需要手动释放。
📝通俗解释:这种方式就像给每个共享的"保险箱"起了个固定的名字(比如"张三的保险箱")。优点是不用保存磁卡(文件描述符),直接报名字就行。但缺点是如果有人用完保险箱忘了销毁记录,这个空保险箱会一直占着地方。
为了解决共享内存文件泄漏的问题,torch.multiprocessing 将产生一个守护进程 torch_shm_manager,它将与自己所在的进程组隔离,并且将跟踪所有共享内存分配。一旦所有连接到它的进程退出,它将等待一会儿,以确保不会有新的连接,并且将遍历该组分配的所有共享内存文件。如果发现任何文件仍然存在,就会释放掉它。我们已经测试了这种方法,并且证明对于各种故障是稳健的。不过,如果您的系统支持 file_descriptor 策略,我们不建议切换到该策略。
📝通俗解释:这就相当于银行雇了一个专门的清洁工(torch_shm_manager),每天检查有没有人借了保险箱但长期不用。如果发现"废弃"的保险箱,就会自动清理掉。不过如果能用磁卡方式(file_descriptor),还是尽量用这个,因为更可靠。
五、torch.multiprocessing 函数使用
下面展示一个采用多进程训练模型的例子:
# 使用多进程训练模型的示例
import torch
import torch.nn as nn
import torch.multiprocessing as mp
def train(model, data_loader, optimizer, loss_fn):
"""
训练函数,将在不同进程中运行
"""
for data, labels in data_loader:
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, labels)
loss.backward()
optimizer.step() # 这将更新共享的参数
# 定义模型
model = nn.Sequential(
nn.Linear(784, 256),
nn.ReLU(),
nn.Linear(256, 10)
)
# 将模型参数移到共享内存
# 这是 'fork' 方法工作所必需的
model.share_memory()
# 创建多个进程
processes = []
num_processes = 4 # 进程数量
for i in range(num_processes):
# 创建进程,target 是目标函数,args 是传递给目标函数的参数
p = mp.Process(target=train, args=(model, None, None, None))
p.start()
processes.append(p)
# 等待所有进程完成
for p in processes:
p.join()📝通俗解释:这段代码展示了如何用 4 个进程同时训练一个模型。想象一下:
- 首先定义了一个"模型"(神经网络)
- 然后调用
share_memory()把模型放到"共享保险箱"里- 接着启动 4 个"工人"(进程),每人拿一份模型去训练
- 所有工人都能同时看到和更新同一个模型(因为模型是共享的)
- 最后等所有工人都干完活(p.join())
关键点说明:
model.share_memory():将模型的所有参数移到共享内存,这是使用 fork 方法所必需的mp.Process:创建一个新的进程p.start():启动进程p.join():等待进程完成
📝通俗解释:
share_memory()就像是把模型从你家里(普通内存)搬到了"公共仓库"(共享内存),这样 4 个工人都能同时访问和修改它,而不需要每个人都复制一份。