显存优化策略篇
来源:AiGC面试宝典 整理:宁静致远 日期:2024年01月27日
一、介绍一下 gradient accumulation 显存优化方式?
基本概念
梯度累积(Gradient Accumulation)是深度学习训练中的一种技术,用于在一次反向传播中累积多个小批量数据的梯度,然后一次性更新模型参数。这个技术的主要目的是在内存有限的情况下,能够有效地使用大批量数据进行训练,从而提高模型性能。
📝 通俗解释:想象一下你要搬一堆砖头,一次只能搬5块,但实际需要搬50块才能完成工作。梯度累积就像是把每次搬的5块砖放在一个临时堆放点,等凑齐10次(50块)后再一次性把砖头放到目的地。这样既不占用额外空间,又能完成大批量的工作。
背景
在深度学习中,通常使用小批量随机梯度下降(Mini-batch SGD)来训练模型。每个小批量数据都会计算一次梯度,并用这个梯度来更新模型参数。然而,由于显存(GPU内存)的限制,无法一次性处理大批量数据。这会限制模型的批量大小,从而影响训练效率和性能。
📝 通俗解释:就像你的电脑内存只有8GB,但训练一个大型AI模型需要16GB内存才能顺利运行。直接运行会内存不足报错,这时候就需要用到显存优化技术。
原理
梯度累积的基本思想是:将多个小批量数据的梯度累积起来,然后一次性更新模型参数。具体操作是:
- 对于每个小批量数据,计算其梯度
- 将这些梯度累积在一起(不清零)
- 当累积的梯度达到一定数量时(累积步数),才执行一次参数更新
📝 通俗解释:原本每个batch计算完梯度就马上更新参数(每次吃一口饭就消化)。现在改成先计算多个batch的梯度累加起来,等积累到一定量再统一更新参数(先把几顿饭存在胃里一起消化)。
作用
- 内存效率:梯度累积允许在内存有限的情况下使用更大的批量数据进行训练。虽然每个小批量的梯度会被累积,但累积过程不会占用额外的内存空间,因此可以充分利用计算资源。
📝 通俗解释:在不增加内存占用的情况下,实现了"假装"使用大batch进行训练的效果。
- 训练稳定性:大批量数据包含更全面和丰富的信息,可以减少梯度的方差,从而提供更稳定的梯度信号,有助于更快地收敛。
📝 通俗解释:数据量越大,梯度方向越准确,就像做市场调研时,调查1000人的结果比调查10人更可靠、更稳定。
- 参数更新频率控制:通过设置累积步数,可以灵活控制参数更新的频率,以适应不同的硬件限制和训练需求。
📝 通俗解释:可以通过调整累积步数来控制"多久更新一次参数",像调节水龙头的开关一样控制训练速度。
代码实现
传统梯度更新方式(每个batch都更新参数):
for (inputs, labels) in data_loader:
# 提取输入和标签
inputs = inputs.to(device)
labels = labels.to(device)
# 前向传播
with torch.set_grad_enabled(True):
preds = model(inputs)
loss = criterion(preds, labels)
# 反向传播
loss.backward()
# 参数更新
optimizer.step()
optimizer.zero_grad()梯度累积方式(累积指定步数后再更新):
gradient_accumulation_steps = 4
for batch_idx, (inputs, labels) in enumerate(data_loader):
# 提取输入和标签
inputs = inputs.to(device)
labels = labels.to(device)
# 前向传播
with torch.set_grad_enabled(True):
preds = model(inputs)
loss = criterion(preds, labels)
# 对loss进行缩放(取平均)
loss /= gradient_accumulation_steps
# 反向传播(计算梯度)
loss.backward()
# 每隔gradient_accumulation_steps个batch或到达最后一个batch时更新参数
if ((batch_idx + 1) % gradient_accumulation_steps == 0) or ((batch_idx + 1) == len(data_loader)):
optimizer.step()
optimizer.zero_grad()📝 通俗解释:关键在于
loss /= gradient_accumulation_steps,这行代码确保最终累积的梯度相当于4个batch的平均梯度,这样更新效果和真正使用4倍batch_size是一样的。
注意事项
- 较大的累积步数可能导致更新频率过低,从而降低训练速度
- 累积梯度可能会影响动量和学习率等参数的计算
- 需要根据具体情况进行参数的设置和调整
二、介绍一下 gradient checkpointing 显存优化方式?
基本概念
梯度检查点(Gradient Checkpointing)是一种优化深度学习模型训练中内存使用的技术。它通过在模型的计算图中插入检查点,将一部分计算推迟到后续步骤进行,从而减少内存占用。这有助于训练更大、更深的模型,以及使用更大批量的数据。
📝 通俗解释:就像你做数学题时,不需要把每一步的计算结果都写在草稿纸上存着(太占地方),而是只记住关键的"检查点"结果。等需要回头检查时,再从最近的检查点重新算一遍。这样可以节省大量内存空间。
背景
深度神经网络通常包含许多层和参数。模型训练需要计算前向传播来获得预测结果,然后通过反向传播来计算梯度以进行参数更新。在大规模模型和数据集上训练时,这些计算会占用大量内存,尤其是在反向传播阶段需要存储大量的中间激活值(因为需要用这些中间值来计算梯度)。
📝 通俗解释:训练一个大型AI模型就像炒一盘复杂的菜,需要很多步骤的中间结果。如果把所有中间步骤的结果都记住(占内存),内存很快就不够用了。
原理
梯度检查点的核心思想是**"用时间换空间"**:
- 分阶段存储:将计算图分成多个段,每段只保存关键结果(检查点)
- 重新计算:在反向传播时,从最近的检查点重新计算所需的中间结果,而不是全部保存
- 减少内存:通过牺牲一部分计算时间来大幅减少内存占用
📝 通俗解释:类似于"记笔记"和"重新推导"的区别。不记详细笔记(省内存),但需要时可以从关键公式重新推导(多花时间)。
作用
- 减少内存压力:通过推迟部分计算和存储中间结果,梯度检查点技术降低了反向传播过程中所需的内存量,使得可以在有限硬件资源下训练更大的模型。
📝 通俗解释:原本需要16GB显存才能训练的大模型,现在可能只需要8GB就能训练。
- 支持大批量训练:使用大批量数据训练可以加快收敛速度和稳定性,但会占用大量内存。梯度检查点允许在大批量训练中有效使用内存。
📝 通俗解释:原来只能batch_size为8训练,现在可以用batch_size为32训练。
- 控制内存开销:允许在内存和计算资源有限的情况下进行实验和研究,无需投入昂贵的硬件。
📝 通俗解释:让普通显卡也能训练大模型,降低了科研门槛。
PyTorch实现示例
from torch.utils.checkpoint import checkpoint_sequential
# 方法1:使用checkpoint_sequential
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.layers = nn.ModuleList([
nn.Linear(256, 256),
nn.Linear(256, 256),
nn.Linear(256, 256),
# ... 更多层
])
def forward(self, x):
# 将层分成多个模块,每个模块作为一个检查点
modules = nn.ModuleList(self.layers[i:i+2] for i in range(0, len(self.layers), 2))
return checkpoint_sequential(modules, len(modules), x)
# 方法2:对单个层使用checkpoint
from torch.utils.checkpoint import checkpoint
def forward_with_checkpoint(x):
# 对计算量大的层使用checkpoint
x = checkpoint(model.layer1, x)
x = model.layer2(x)
return x📝 通俗解释:代码中
checkpoint_sequential将模型分成若干段,每段只保存关键结果,段内重新计算,从而节省显存。
注意事项
- 计算开销:梯度检查点会引入额外的重新计算开销,训练时间可能增加20%-30%
- 使用场景:适合模型特别大、内存不足的情况;如果内存足够,直接用更高效
- 与梯度累积结合:两种技术可以叠加使用,进一步节省内存
📝 通俗解释:这是典型的"空间换时间"策略。用多花一点计算时间(重新推导)来换取更少的内存占用,就像用硬盘空间换内存空间一样。
总结对比
| 特性 | Gradient Accumulation | Gradient Checkpointing |
|---|---|---|
| 核心思想 | 累积多个batch的梯度再更新 | 只存关键检查点,其余重新计算 |
| 作用 | 增大有效batch_size | 减少中间激活值内存占用 |
| 计算开销 | 基本无额外开销 | 增加20%-30%计算时间 |
| 适用场景 | 显存够但batch_size小 | 模型太大、显存不足 |
📝 通俗解释:这两种技术就像两个不同的"省内存"方法。梯度累积是"少更新几次",梯度检查点是"少记中间步骤"。可以根据实际情况单独使用或组合使用。
整理完成