Skip to content

图解分布式训练(五)—— AMP混合精度训练 详细解析

来源:AI大模型面试宝典
作者:宁静致远
日期:2023年09月29日

一、为什么需要 AMP 混合精度训练?

PyTorch 1.6 版本正式发布了,其中带来的最大更新就是自动混合精度(Automatic Mixed Precision,AMP)。Release 说明的标题是:

Stable release of automatic mixed precision (AMP). New Beta features include a TensorPipe backend for RPC, memory profiler, and several improvements to distributed training for both RPC and DDP.

可见自动混合精度正是 PyTorch 1.6 的最大更新。这就带来了几个问题:

  • 什么是自动混合精度训练?
  • 为什么需要自动混合精度?
  • 如何在 PyTorch 中使用自动混合精度?

📝通俗解释:PyTorch 1.6 版本新增的 AMP 功能就像给训练过程装了一个"智能调度器",它能自动决定什么时候用高精度,什么时候用低精度,就像我们在计算时会选择用计算器还是心算一样。

二、什么是自动混合精度训练(AMP)

我们知道神经网络框架的计算核心是 Tensor,也就是那个从 scalar -> array -> matrix -> tensor 维度一路丰富过来的概念。在 PyTorch 中,我们可以这样创建一个 Tensor:

python
>>> import torch

>>> gemfield = torch.zeros(70, 30)
>>> gemfield.type()
'torch.FloatTensor'

>>> syszux = torch.Tensor([1, 2])
>>> syszux.type()
'torch.FloatTensor'

可以看到默认创建的 Tensor 都是 FloatTensor 类型。而在 PyTorch 中,一共有 10 种类型的 Tensor:

类型说明
torch.FloatTensor32位浮点型
torch.DoubleTensor64位浮点型
torch.HalfTensor16位浮点型(半精度)
torch.BFloat16Tensor16位浮点型(脑浮点)
torch.ByteTensor8位无符号整数
torch.CharTensor8位有符号整数
torch.ShortTensor16位有符号整数
torch.IntTensor32位有符号整数
torch.LongTensor64位有符号整数
torch.BoolTensor布尔型

📝通俗解释:Tensor 就是 PyTorch 中的数据容器,可以理解为多维数组。默认情况下,PyTorch 使用 32 位浮点数(FloatTensor),就像我们用普通的计算器一样,精度高但占用空间大。

自动混合精度的关键词有两个:自动、混合精度。 这是由 PyTorch 1.6 的 torch.cuda.amp 模块带来的:

python
from torch.cuda.amp import autocast as autocast

混合精度预示着有不止一种精度的 Tensor,那在 PyTorch 的 AMP 模块里是几种呢?2 种:torch.FloatTensor(32位)和 torch.HalfTensor(16位半精度)

📝通俗解释:混合精度就是"粗细结合"——有些计算用 32 位(精细),有些计算用 16 位(粗糙)。这样既能保证训练效果,又能提升速度和节省显存。

自动预示着 Tensor 的 dtype 类型会自动变化,也就是框架按需自动调整 Tensor 的 dtype(其实不是完全自动,有些地方还是需要手工干预)。

torch.cuda.amp 的名字意味着这个功能只能在 CUDA 上使用,事实上,这个功能正是 NVIDIA 的开发人员贡献到 PyTorch 项目中的。而只有支持 Tensor Core 的 CUDA 硬件才能享受到 AMP 的好处(比如 2080Ti 显卡)。Tensor Core 是一种矩阵乘累加的计算单元,每个 Tensor Core 每个时钟执行 64 个浮点混合精度操作(FP16 矩阵相乘和 FP32 累加),英伟达宣称使用 Tensor Core 进行矩阵运算可以轻易提速,同时降低一半的显存访问和存储。

📝通俗解释:Tensor Core 是 NVIDIA 显卡上的一种特殊计算单元,专门处理矩阵运算。它就像一个"专用计算器",专门做 16 位浮点数的矩阵乘法,然后用 32 位浮点数累加结果,速度特别快。

因此,在 PyTorch 中,当我们提到自动混合精度训练,我们说的就是在 NVIDIA 的支持 Tensor Core 的 CUDA 设备上使用 torch.cuda.amp.autocast(以及 torch.cuda.amp.GradScaler)来进行训练。

咦?为什么还要有 torch.cuda.amp.GradScaler

三、为什么需要自动混合精度?

这个问题其实暗含着这样的意思:为什么需要自动混合精度,也就是 torch.FloatTensor 和 torch.HalfTensor 的混合,而不全是 torch.FloatTensor?或者全是 torch.HalfTensor?

如果非要以这种方式问,那么答案只能是,在某些上下文中 torch.FloatTensor 有优势,在某些上下文中 torch.HalfTensor 有优势。答案进一步可以转化为,相比于之前的默认的 torch.FloatTensortorch.HalfTensor 有时具有优势,有时劣势不可忽视。

3.1 torch.HalfTensor 的优势

  • 存储小:占用的显存更少
  • 计算快:运算速度更快
  • 更好的利用 CUDA 设备的 Tensor Core:能充分发挥显卡的计算能力

因此训练的时候可以减少显存的占用(可以增加 batch size 了),同时训练速度更快。

📝通俗解释:想象一下,HalfTensor 就像是把高清照片压缩成略低分辨率的版本,虽然细节少了,但加载快、占地方小。对于训练深度学习模型来说,这就意味着能用更大的批次(batch size)来训练,而且跑得更快。

3.2 torch.HalfTensor 的劣势

  • 数值范围小:更容易出现 Overflow(溢出)和 Underflow(下溢)
  • 舍入误差:Rounding Error 导致一些微小的梯度信息达不到 16bit 精度的最低分辨率,从而丢失

📝通俗解释:HalfTensor 的问题就像压缩图片太狠后会失真一样——数字太小了会变成 0(梯度消失),数字太大了会变成无穷大(梯度爆炸)。这会导致训练不稳定,模型学不到东西。

3.3 解决方案

为了解决 torch.HalfTensor 的劣势,我们带来了两种解决方案:

  1. 梯度 scale(缩放):这正是 torch.cuda.amp.GradScaler 的作用,通过放大 loss 的值来防止梯度的 underflow(这只是 BP 的时候传递梯度信息使用,真正更新权重的时候还是要把放大的梯度再 unscale 回去);

  2. 回落到 torch.FloatTensor:这就是混合一词的由来。

那怎么知道什么时候用 torch.FloatTensor,什么时候用半精度浮点型呢?这是 PyTorch 框架决定的,在 PyTorch 1.6 的 AMP 上下文中,如下操作中 tensor 会被自动转化为半精度浮点型的 torch.HalfTensor

__matmul__, addbmm, addmm, addmv, addr, baddbmm, bmm, chain_matmul,
conv1d, conv2d, conv3d, conv_transpose1d, conv_transpose2d, conv_transpose3d,
linear, matmul, mm, mv, prelu

📝通俗解释:PyTorch 会自动判断哪些操作可以用半精度(比如矩阵乘法、卷积等),哪些必须用全精度(比如优化器更新参数)。这就像一个智能的"调度员",知道什么时候该用什么精度。

3.4 AMP 训练流程图

[图表描述] 图片展示了一个 AMP(自动混合精度)训练流程的示意图:

  • 左侧流程(权重更新):原始参数 parameter (float32) 经过 cast (float32 to float16) 转换为 parameter (float16)
  • 上方流程(前向传播)parameter (float16)activation (float16) 进入 forward compute,输出 activation (float16)
  • 右侧流程(反向传播)activation (float16)parameter/activation (float16) 进入 backward propagation,输出 activation_grad/weight_grad (float16)
  • 下方流程(梯度更新)activation_grad/weight_grad (float16) 经过 cast (float16 to float32) 转换为 activation_grad/weight_grad (float32),最后进入 parameter update 更新回 parameter (float32)

📝通俗解释:这个流程就像一个"中转站":前向和反向传播用 16 位计算(快),但参数更新时转回 32 位(精确)。这样既享受了半精度的好处,又避免了它的缺点。

四、混合精度训练的优点是什么?

  • 减少显存占用
  • 加快训练速度
  • 通信量减半,计算性能翻倍

📝通俗解释:使用混合精度训练,就像从开车改成坐高铁——同样的人和货物(模型数据),占用的空间更少(显存),跑得还更快(速度),而且过隧道时需要传递的东西也少了(通信量)。

五、混合精度训练的缺点是什么?

  • 数据溢出:由于数值范围小,容易出现数值溢出
  • 舍入误差:精度不足导致部分梯度信息丢失

📝通俗解释:缺点就像压缩文件一样——虽然方便,但可能会有细节丢失。如果压缩太狠(精度太低),重要的信息可能就找不回来了。

六、混合精度训练的关键技术是什么?

  • float32 主权重备份:保持一份 32 位精度的权重副本,确保参数更新时使用高精度
  • 动态损失缩放:通过动态调整损失缩放因子,防止梯度下溢

📝通俗解释:float32 主权重备份就像保留原件,只用复印件去做计算;而动态损失缩放就像给梯度"戴个放大镜",让微小的梯度也能被看清。

七、混合精度训练之动态损失缩放详解

7.1 什么是动态损失缩放?

动态损失缩放(Dynamic Loss Scaling)是一种防止梯度下溢的技术。具体做法是:

只需要将损失乘以某个大数字(如 1024),这会将梯度也放大 1024 倍,大大降低了梯度发生下溢的几率。计算出梯度后,只需将其除以 1024 就可以得到准确值。

📝通俗解释:就像我们要称一堆很轻的东西(微小梯度),直接称可能显示为 0。但如果把所有东西放在一起称(放大损失),就能称出重量了,称完再除以数量就知道单个的重量了。

7.2 动态选择损失标度

  • 发生溢出时:跳过优化器更新,损失标度减半
  • 连续 N 个 steps 没有发生溢出:损失标度翻倍

这种动态调整的机制既能保证训练稳定,又能尽可能大地使用缩放因子,从而获得更好的训练效果。

📝通俗解释:这就像开车时调节车速——如果发现要撞车了(梯度溢出),就减速(减小缩放因子);如果开了很久都没问题,就试着提速(增大缩放因子),直到找到最合适的速度。

7.3 动态损失缩放示意图

[图表描述] 图片展示了动态损失缩放的效果:

  • 左侧是一个直方图:横轴为 log(magnitude),纵轴为 Percentage of all activation gradient values。图中标注了 FP16 Representable range(FP16 可表示范围)、Become zero in FP16(在 FP16 中变为零)、FP16 denorms。

  • 右侧是一个折线图:标题为 loss_scale,横轴为步数(0, 1k, 2k... 10k),纵轴为数值(0, 1e+4, 2e+4, 3e+4)。曲线显示 loss_scale 随着步数增加呈现阶梯状下降,最后趋于平稳。

📝通俗解释:左边的图告诉我们,很多梯度在 FP16 里是"零"(看不见的);右边的图显示,loss_scale 会随着训练过程动态调整,保证训练既稳定又高效。

八、如何在 PyTorch 中使用自动混合精度?

答案就是 autocast + GradScaler

8.1 autocast 的使用

正如前文所说,需要使用 torch.cuda.amp 模块中的 autocast 类。使用也是非常简单的:

python
from torch.cuda.amp import autocast as autocast

# 创建 model,默认是 torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

for input, target in data:
    optimizer.zero_grad()

    # 前向过程(model + loss)开启 autocast
    with autocast():
        output = model(input)
        loss = loss_fn(output, target)

    # 反向传播在 autocast 上下文之外
    loss.backward()
    optimizer.step()

可以使用 autocast 的 context managers 语义(如上所示),也可以使用 decorators 语义。当进入 autocast 的上下文后,上面列出来的那些 CUDA ops 会把 tensor 的 dtype 转换为半精度浮点型,从而在不损失训练精度的情况下加快运算。

📝通俗解释:autocast 就像一个"开关",打开后框架会自动把合适的计算换成半精度。只需要把前向传播的过程包在这个"开关"里就行,非常方便。

刚进入 autocast 的上下文时,tensor 可以是任何类型,你不要在 model 或者 input 上手工调用 .half(),框架会自动做,这也是自动混合精度中"自动"一词的由来。

另外一点就是,autocast 上下文应该只包含网络的前向过程(包括 loss 的计算),而不要包含反向传播,因为 BP 的 op 会使用和前向 op 相同的类型。

📝通俗解释:反向传播不需要放在 autocast 里,因为 PyTorch 会自动使用和前向传播相同的精度。

还有的时候,你的代码在 autocast 上下文中会报如下的错误:

RuntimeError: expected scalar type float but found c10::Half

对于这个错误,你可以在 tensor 上手工调用 .float() 来让 type 匹配。

8.2 GradScaler 的使用

但是别忘了前面提到的梯度 scaler 模块呀,需要在训练最开始之前实例化一个 GradScaler 对象。因此 PyTorch 中经典的 AMP 使用方式如下:

python
from torch.cuda.amp import autocast as autocast, GradScaler

# 创建 model,默认是 torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

# 在训练最开始之前实例化一个 GradScaler 对象
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        # 前向过程(model + loss)开启 autocast
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)

        # Scales loss. 为了梯度放大.
        scaler.scale(loss).backward()

        # scaler.step() 首先把梯度的值 unscale 回来.
        # 如果梯度的值不是 infs 或者 NaNs, 那么调用 optimizer.step() 来更新权重,
        # 否则,忽略 step 调用,从而保证权重不更新(不被破坏)
        scaler.step(optimizer)

        # 准备着,看是否要增大 scaler
        scaler.update()

📝通俗解释:GradScaler 就像一个"安全阀"——它会先放大梯度进行检查,如果发现梯度正常就更新参数,如果发现梯度异常(溢出)就跳过更新并减小缩放因子。这样既保证了训练速度,又保证了训练稳定性。

scaler 的大小在每次迭代中动态地估计,为了尽可能地减少梯度 underflow,scaler 应该更大;但是如果太大的话,半精度浮点型的 tensor 又容易 overflow(变成 inf 或者 NaN)。所以动态估计的原理就是在不出现 inf 或者 NaN 梯度值的情况下尽可能地增大 scaler 的值——在每次 scaler.step(optimizer) 中,都会检查是否有 inf 或 NaN 的梯度出现:

  • 如果出现了 inf 或者 NaN,scaler.step(optimizer) 会忽略此次的权重更新(optimizer.step()),并且将 scaler 的大小缩小(乘上 backoff_factor);
  • 如果没有出现 inf 或者 NaN,那么权重正常更新,并且当连续多次(growth_interval 指定)没有出现 inf 或 NaN,则 scaler.update() 会将 scaler 的大小增加(乘上 growth_factor)。

九、AMP 混合精度训练完整代码示例

9.1 Trainer 训练类

python
class Trainer:
    ...
    def train(self, train_loader, dev_loader=None, train_sampler=None):
        ...
        # 设置 AMP 混合精度训练
        if self.args.use_amp:
            scaler = torch.cuda.amp.GradScaler()
        
        if self.args.local_rank == 0:
            start = time.time()
        
        for epoch in range(1, self.args.epochs + 1):
            train_sampler.set_epoch(epoch)
            for step, batch_data in enumerate(train_loader):
                self.model.train()
                
                # 使用 AMP 混合精度训练
                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        logits, label = self.on_step(batch_data)
                        loss = self.criterion(logits, label)
                        torch.distributed.barrier()
                        scaler.scale(loss).backward()
                        scaler.step(self.optimizer)
                        scaler.update()
                else:
                    logits, label = self.on_step(batch_data)
                    loss = self.criterion(logits, label)
                    torch.distributed.barrier()
                    loss.backward()
                    self.optimizer.step()

        if self.args.local_rank == 0:
            end = time.time()
            print("耗时:{}分钟".format((end - start) / 60))
        
        if not self.args.dev and self.args.local_rank == 0:
            torch.save(self.model.state_dict(), self.args.ckpt_path)

9.2 Args 参数类

python
class Args:
    ...
    local_rank = None
    local_world_size = None
    device_ids = None
    rank = None
    dev = False
    use_amp = True  # 开启 AMP 混合精度训练

9.3 main_worker 主函数

python
def main_worker(local_rank, local_world_size):
    # ==============================
    # 设置参数
    ...
    dist.init_process_group(backend="nccl", init_method="tcp://localhost:12345",
                            world_size=local_world_size, rank=local_rank)

    n = torch.cuda.device_count() // local_world_size
    device_ids = [local_rank]
    print(
        f"[{os.getpid()}] rank = {local_rank}, "
        + f"world_size = {local_world_size}, n = {n}, device_ids = {device_ids} \n", end=''
    )

    torch.cuda.set_device(local_rank)

    args = Args()
    args.local_world_size = local_world_size
    args.local_rank = local_rank
    args.device_ids = device_ids
    args.rank = local_rank
    tokenizer = BertTokenizer.from_pretrained(args.model_path)
    ...
    # 封装模型
    model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=args.device_ids)
    ...

9.4 AMP 混合精度训练完整代码

python
import os
import time
import json
import random
import torch
import torch.nn as nn
import numpy as np
import torch.distributed as dist
import torch.multiprocessing as mp

from collections import Counter
from tqdm import tqdm
from sklearn.metrics import classification_report
from torch.utils.data import DataLoader, Dataset
from transformers import BertForMaskedLM, BertTokenizer, BertForSequenceClassification, BertConfig, AdamW


def set_seed(seed=123):
    """
    设置随机数种子,保证实验可重现
    """
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)


def get_data():
    with open("data/train.json", "r", encoding="utf-8") as fp:
        data = fp.read()
    data = json.loads(data)
    return data


def load_data():
    data = get_data()
    return_data = []
    # [(文本, 标签id)]
    for d in data:
        text = d[0]
        label = d[1]
        return_data.append(("".join(text.split(" ")).strip(), label))
    return return_data


class ClsDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]


class Collate:
    def __init__(self, tokenizer, max_seq_len):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

    def collate_fn(self, batch):
        input_ids_all = []
        token_type_ids_all = []
        attention_mask_all = []
        label_all = []
        
        for data in batch:
            text = data[0]
            label = data[1]
            inputs = self.tokenizer.encode_plus(
                text=text,
                max_length=self.max_seq_len,
                padding="max_length",
                truncation="longest_first",
                return_attention_mask=True,
                return_token_type_ids=True
            )
            input_ids = inputs["input_ids"]
            token_type_ids = inputs["token_type_ids"]
            attention_mask = inputs["attention_mask"]
            input_ids_all.append(input_ids)
            token_type_ids_all.append(token_type_ids)
            attention_mask_all.append(attention_mask)
            label_all.append(label)

        input_ids_all = torch.tensor(input_ids_all, dtype=torch.long)
        token_type_ids_all = torch.tensor(token_type_ids_all, dtype=torch.long)
        attention_mask_all = torch.tensor(attention_mask_all, dtype=torch.long)
        label_all = torch.tensor(label_all, dtype=torch.long)
        
        return_data = {
            "input_ids": input_ids_all,
            "attention_mask": attention_mask_all,
            "token_type_ids": token_type_ids_all,
            "label": label_all
        }
        return return_data


def build_optimizer(model, args):
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    return optimizer


class Trainer:
    def __init__(self, args, config, model, criterion, optimizer):
        self.args = args
        self.config = config
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer

    def on_step(self, batch_data):
        label = batch_data["label"].cuda()
        input_ids = batch_data["input_ids"].cuda()
        token_type_ids = batch_data["token_type_ids"].cuda()
        attention_mask = batch_data["attention_mask"].cuda()
        
        output = self.model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            labels=label
        )
        logits = output[1]
        return logits, label

    def loss_reduce(self, loss):
        rt = loss.clone()
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
        rt /= self.args.local_world_size
        return rt

    def output_reduce(self, outputs, targets):
        output_gather_list = [torch.zeros_like(outputs) for _ in range(self.args.local_world_size)]
        dist.all_gather(output_gather_list, outputs)
        outputs = torch.cat(output_gather_list, dim=0)
        
        target_gather_list = [torch.zeros_like(targets) for _ in range(self.args.local_world_size)]
        dist.all_gather(target_gather_list, targets)
        targets = torch.cat(target_gather_list, dim=0)
        return outputs, targets

    def train(self, train_loader, dev_loader=None, train_sampler):
        global_step = 1
        best_acc = 0.0
        
        if self.args.use_amp:
            scaler = torch.cuda.amp.GradScaler()
        
        if self.args.local_rank == 0:
            start = time.time()
        
        for epoch in range(1, self.args.epochs + 1):
            train_sampler.set_epoch(epoch)
            for step, batch_data in enumerate(train_loader):
                self.model.train()
                
                if self.args.use_amp:
                    with torch.cuda.amp.autocast():
                        logits, label = self.on_step(batch_data)
                        loss = self.criterion(logits, label)
                        torch.distributed.barrier()
                        scaler.scale(loss).backward()
                        scaler.step(self.optimizer)
                        scaler.update()
                else:
                    logits, label = self.on_step(batch_data)
                    loss = self.criterion(logits, label)
                    torch.distributed.barrier()
                    loss.backward()
                    self.optimizer.step()
                
                if self.args.local_rank == 0:
                    print("【train】 epoch: {}/{} step: {}/{} loss: {:.6f}".format(
                        epoch, self.args.epochs, global_step, self.args.total_step, loss
                    ))
                
                global_step += 1
                
                if self.args.dev:
                    if global_step % self.args.eval_step == 0:
                        loss, accuracy = self.dev(dev_loader)
                        if self.args.local_rank == 0:
                            print("【dev】 loss: {:.6f} accuracy: {:.4f}".format(loss, accuracy))
                            if accuracy > best_acc:
                                best_acc = accuracy
                                print("【best accuracy】 {:.4f}".format(best_acc))
                                torch.save(self.model.state_dict(), self.args.ckpt_path)
        
        if self.args.local_rank == 0:
            end = time.time()
            print("耗时:{}分钟".format((end - start) / 60))
        
        if not self.args.dev and self.args.local_rank == 0:
            torch.save(self.model.state_dict(), self.args.ckpt_path)

    def dev(self, dev_loader):
        self.model.eval()
        correct_total = 0
        num_total = 0
        loss_total = 0.0
        
        with torch.no_grad():
            for step, batch_data in tqdm(enumerate(dev_loader)):
                logits, label = self.on_step(batch_data)
                loss = self.criterion(logits, label)
                torch.distributed.barrier()
                loss = self.loss_reduce(loss)
                loss_total += loss
                
                logits, label = self.output_reduce(logits, label)
                logits = logits.detach().cpu().numpy()
                label = label.view(-1).detach().cpu().numpy()
                num_total += len(label)
                preds = np.argmax(logits, axis=1).flatten()
                correct_num = (preds == label).sum()
                correct_total += correct_num
        
        return loss_total, correct_total / num_total

    def test(self, model, test_loader, labels):
        self.model = model
        self.model.eval()
        preds = []
        trues = []
        
        with torch.no_grad():
            for step, batch_data in enumerate(test_loader):
                logits, label = self.on_step(batch_data)
                torch.distributed.barrier()
                logits, label = self.output_reduce(logits, label)
                label = label.view(-1).detach().cpu().numpy().tolist()
                logits = logits.detach().cpu().numpy()
                pred = np.argmax(logits, axis=1).flatten().tolist()
                trues.extend(label)
                preds.extend(pred)
        
        report = classification_report(trues, preds, target_names=labels)
        return report


class Args:
    model_path = "/mnt/kaimo/data/pretrain/bert-base-chinese"
    ckpt_path = "output/multi-gpu-distributed-mp-amp-cls.pt"
    max_seq_len = 128
    ratio = 0.92
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_batch_size = 32
    dev_batch_size = 32
    weight_decay = 0.01
    epochs = 1
    learning_rate = 3e-5
    eval_step = 50
    local_rank = None
    local_world_size = None
    device_ids = None
    rank = None
    dev = False
    use_amp = True  # 开启 AMP 混合精度训练


def main_worker(local_rank, local_world_size):
    # 设置随机种子
    set_seed()
    
    label2id = {
        "其他": 0,
        "喜好": 1,
        "悲伤": 2,
        "厌恶": 3,
        "愤怒": 4,
        "高兴": 5,
    }
    
    # 初始化分布式训练
    dist.init_process_group(
        backend="nccl", 
        init_method="tcp://localhost:12345",
        world_size=local_world_size,
        rank=local_rank
    )

    n = torch.cuda.device_count() // local_world_size
    device_ids = [local_rank]
    print(
        f"[{os.getpid()}] rank = {local_rank}, "
        + f"world_size = {local_world_size}, n = {n}, device_ids = {device_ids} \n", end=''
    )

    torch.cuda.set_device(local_rank)

    args = Args()
    args.local_world_size = local_world_size
    args.local_rank = local_rank
    args.device_ids = device_ids
    args.rank = local_rank
    tokenizer = BertTokenizer.from_pretrained(args.model_path)

    # 加载数据集
    data = load_data()
    data = data[:10000]  # 取1万条数据
    random.shuffle(data)
    train_num = int(len(data) * args.ratio)
    train_data = data[:train_num]
    dev_data = data[train_num:]

    collate = Collate(tokenizer, args.max_seq_len)
    train_dataset = ClsDataset(train_data)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        num_workers=2,
        collate_fn=collate.collate_fn,
        sampler=train_sampler
    )
    total_step = len(train_loader) * args.epochs
    args.total_step = total_step
    
    dev_dataset = ClsDataset(dev_data)
    dev_sampler = torch.utils.data.distributed.DistributedSampler(dev_dataset)
    dev_loader = DataLoader(
        dev_dataset,
        batch_size=args.dev_batch_size,
        shuffle=False,
        num_workers=2,
        collate_fn=collate.collate_fn,
        sampler=dev_sampler
    )
    test_loader = dev_loader

    # 定义模型、优化器、损失函数
    config = BertConfig.from_pretrained(args.model_path, num_labels=6)
    model = BertForSequenceClassification.from_pretrained(args.model_path, config=config)
    
    # 封装模型
    model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=args.device_ids)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = build_optimizer(model, args)

    # 定义训练器,进行训练、验证和测试
    trainer = Trainer(args, config, model, criterion, optimizer)
    trainer.train(train_loader, dev_loader, train_sampler)

    # 测试
    labels = list(label2id.keys())
    config = BertConfig.from_pretrained(args.model_path, num_labels=6)
    model = BertForSequenceClassification.from_pretrained(args.model_path, config=config)
    model.cuda(args.local_rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=args.device_ids)
    model.load_state_dict(torch.load(args.ckpt_path))
    report = trainer.test(model, test_loader, labels)
    
    if args.local_rank == 0:
        print(report)

    # 销毁进程组
    dist.destroy_process_group()


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_world_size", type=int, default=1)
    p_args = parser.parse_args()
    
    # 启动进程
    mp.spawn(main_worker, nprocs=p_args.local_world_size, args=(p_args.local_world_size,))

📝通俗解释:这是一个完整的分布式训练 + AMP 混合精度训练的示例。核心就是在训练循环中加入 torch.cuda.amp.autocast()GradScaler,就像给训练过程加上了"加速器",既快又省显存。

总结

关键技术说明
autocast自动将前向传播中的计算转换为半精度
GradScaler动态缩放梯度,防止下溢
float32 主权重保留一份高精度权重用于参数更新
动态损失缩放根据训练状态自动调整缩放因子

使用 AMP 混合精度训练可以带来以下好处:

  • 显存节省约 50%:可以训练更大的模型或使用更大的 batch size
  • 训练速度提升:利用 Tensor Core 加速矩阵运算
  • 通信量减半:减少分布式训练中的通信开销

📝通俗解释:AMP 混合精度训练就像给深度学习训练装上了"涡轮增压"——同样的硬件配置下,跑得更快、装的东西更多、还更省油。

基于 MIT 许可发布