Skip to content

显存优化策略篇

来源:AiGC面试宝典 整理:宁静致远 日期:2024年01月27日


一、介绍一下 gradient accumulation 显存优化方式?

基本概念

梯度累积(Gradient Accumulation)是深度学习训练中的一种技术,用于在一次反向传播中累积多个小批量数据的梯度,然后一次性更新模型参数。这个技术的主要目的是在内存有限的情况下,能够有效地使用大批量数据进行训练,从而提高模型性能。

📝 通俗解释:想象一下你要搬一堆砖头,一次只能搬5块,但实际需要搬50块才能完成工作。梯度累积就像是把每次搬的5块砖放在一个临时堆放点,等凑齐10次(50块)后再一次性把砖头放到目的地。这样既不占用额外空间,又能完成大批量的工作。

背景

在深度学习中,通常使用小批量随机梯度下降(Mini-batch SGD)来训练模型。每个小批量数据都会计算一次梯度,并用这个梯度来更新模型参数。然而,由于显存(GPU内存)的限制,无法一次性处理大批量数据。这会限制模型的批量大小,从而影响训练效率和性能。

📝 通俗解释:就像你的电脑内存只有8GB,但训练一个大型AI模型需要16GB内存才能顺利运行。直接运行会内存不足报错,这时候就需要用到显存优化技术。

原理

梯度累积的基本思想是:将多个小批量数据的梯度累积起来,然后一次性更新模型参数。具体操作是:

  1. 对于每个小批量数据,计算其梯度
  2. 将这些梯度累积在一起(不清零)
  3. 当累积的梯度达到一定数量时(累积步数),才执行一次参数更新

📝 通俗解释:原本每个batch计算完梯度就马上更新参数(每次吃一口饭就消化)。现在改成先计算多个batch的梯度累加起来,等积累到一定量再统一更新参数(先把几顿饭存在胃里一起消化)。

作用

  1. 内存效率:梯度累积允许在内存有限的情况下使用更大的批量数据进行训练。虽然每个小批量的梯度会被累积,但累积过程不会占用额外的内存空间,因此可以充分利用计算资源。

📝 通俗解释:在不增加内存占用的情况下,实现了"假装"使用大batch进行训练的效果。

  1. 训练稳定性:大批量数据包含更全面和丰富的信息,可以减少梯度的方差,从而提供更稳定的梯度信号,有助于更快地收敛。

📝 通俗解释:数据量越大,梯度方向越准确,就像做市场调研时,调查1000人的结果比调查10人更可靠、更稳定。

  1. 参数更新频率控制:通过设置累积步数,可以灵活控制参数更新的频率,以适应不同的硬件限制和训练需求。

📝 通俗解释:可以通过调整累积步数来控制"多久更新一次参数",像调节水龙头的开关一样控制训练速度。

代码实现

传统梯度更新方式(每个batch都更新参数):

python
for (inputs, labels) in data_loader:
    # 提取输入和标签
    inputs = inputs.to(device)
    labels = labels.to(device)

    # 前向传播
    with torch.set_grad_enabled(True):
        preds = model(inputs)
        loss = criterion(preds, labels)

        # 反向传播
        loss.backward()

        # 参数更新
        optimizer.step()
        optimizer.zero_grad()

梯度累积方式(累积指定步数后再更新):

python
gradient_accumulation_steps = 4

for batch_idx, (inputs, labels) in enumerate(data_loader):
    # 提取输入和标签
    inputs = inputs.to(device)
    labels = labels.to(device)

    # 前向传播
    with torch.set_grad_enabled(True):
        preds = model(inputs)
        loss = criterion(preds, labels)
        
        # 对loss进行缩放(取平均)
        loss /= gradient_accumulation_steps

        # 反向传播(计算梯度)
        loss.backward()

    # 每隔gradient_accumulation_steps个batch或到达最后一个batch时更新参数
    if ((batch_idx + 1) % gradient_accumulation_steps == 0) or ((batch_idx + 1) == len(data_loader)):
        optimizer.step()
        optimizer.zero_grad()

📝 通俗解释:关键在于 loss /= gradient_accumulation_steps,这行代码确保最终累积的梯度相当于4个batch的平均梯度,这样更新效果和真正使用4倍batch_size是一样的。

注意事项

  1. 较大的累积步数可能导致更新频率过低,从而降低训练速度
  2. 累积梯度可能会影响动量和学习率等参数的计算
  3. 需要根据具体情况进行参数的设置和调整

二、介绍一下 gradient checkpointing 显存优化方式?

基本概念

梯度检查点(Gradient Checkpointing)是一种优化深度学习模型训练中内存使用的技术。它通过在模型的计算图中插入检查点,将一部分计算推迟到后续步骤进行,从而减少内存占用。这有助于训练更大、更深的模型,以及使用更大批量的数据。

📝 通俗解释:就像你做数学题时,不需要把每一步的计算结果都写在草稿纸上存着(太占地方),而是只记住关键的"检查点"结果。等需要回头检查时,再从最近的检查点重新算一遍。这样可以节省大量内存空间。

背景

深度神经网络通常包含许多层和参数。模型训练需要计算前向传播来获得预测结果,然后通过反向传播来计算梯度以进行参数更新。在大规模模型和数据集上训练时,这些计算会占用大量内存,尤其是在反向传播阶段需要存储大量的中间激活值(因为需要用这些中间值来计算梯度)。

📝 通俗解释:训练一个大型AI模型就像炒一盘复杂的菜,需要很多步骤的中间结果。如果把所有中间步骤的结果都记住(占内存),内存很快就不够用了。

原理

梯度检查点的核心思想是**"用时间换空间"**:

  1. 分阶段存储:将计算图分成多个段,每段只保存关键结果(检查点)
  2. 重新计算:在反向传播时,从最近的检查点重新计算所需的中间结果,而不是全部保存
  3. 减少内存:通过牺牲一部分计算时间来大幅减少内存占用

📝 通俗解释:类似于"记笔记"和"重新推导"的区别。不记详细笔记(省内存),但需要时可以从关键公式重新推导(多花时间)。

作用

  1. 减少内存压力:通过推迟部分计算和存储中间结果,梯度检查点技术降低了反向传播过程中所需的内存量,使得可以在有限硬件资源下训练更大的模型。

📝 通俗解释:原本需要16GB显存才能训练的大模型,现在可能只需要8GB就能训练。

  1. 支持大批量训练:使用大批量数据训练可以加快收敛速度和稳定性,但会占用大量内存。梯度检查点允许在大批量训练中有效使用内存。

📝 通俗解释:原来只能batch_size为8训练,现在可以用batch_size为32训练。

  1. 控制内存开销:允许在内存和计算资源有限的情况下进行实验和研究,无需投入昂贵的硬件。

📝 通俗解释:让普通显卡也能训练大模型,降低了科研门槛。

PyTorch实现示例

python
from torch.utils.checkpoint import checkpoint_sequential

# 方法1:使用checkpoint_sequential
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layers = nn.ModuleList([
            nn.Linear(256, 256),
            nn.Linear(256, 256),
            nn.Linear(256, 256),
            # ... 更多层
        ])
    
    def forward(self, x):
        # 将层分成多个模块,每个模块作为一个检查点
        modules = nn.ModuleList(self.layers[i:i+2] for i in range(0, len(self.layers), 2))
        return checkpoint_sequential(modules, len(modules), x)

# 方法2:对单个层使用checkpoint
from torch.utils.checkpoint import checkpoint

def forward_with_checkpoint(x):
    # 对计算量大的层使用checkpoint
    x = checkpoint(model.layer1, x)
    x = model.layer2(x)
    return x

📝 通俗解释:代码中 checkpoint_sequential 将模型分成若干段,每段只保存关键结果,段内重新计算,从而节省显存。

注意事项

  1. 计算开销:梯度检查点会引入额外的重新计算开销,训练时间可能增加20%-30%
  2. 使用场景:适合模型特别大、内存不足的情况;如果内存足够,直接用更高效
  3. 与梯度累积结合:两种技术可以叠加使用,进一步节省内存

📝 通俗解释:这是典型的"空间换时间"策略。用多花一点计算时间(重新推导)来换取更少的内存占用,就像用硬盘空间换内存空间一样。


总结对比

特性Gradient AccumulationGradient Checkpointing
核心思想累积多个batch的梯度再更新只存关键检查点,其余重新计算
作用增大有效batch_size减少中间激活值内存占用
计算开销基本无额外开销增加20%-30%计算时间
适用场景显存够但batch_size小模型太大、显存不足

📝 通俗解释:这两种技术就像两个不同的"省内存"方法。梯度累积是"少更新几次",梯度检查点是"少记中间步骤"。可以根据实际情况单独使用或组合使用。


整理完成

基于 MIT 许可发布