大模型的分布式训练策略与ZeRO优化

描述

模型并行与数据并行:分布式训练策略与 ZeRO 优化

适用读者:要在多卡/多机上训练 7B-70B 大模型的工程师。本文不讲 NCCL 源码,只讲:每种并行策略(DP/DDP/TP/PP/ZeRO/FSDP)适合什么规模、显存怎么算、通信开销多大、配置怎么写、坑在哪里。读完你应该能自己拍板"我的 70B 模型在 8 卡 A100 上该用哪种并行"。

引子:那个 70B 模型"理论上能跑"的悲剧

去年我们接了一个项目,要在 16 块 A100 80G 上微调一个 70B 模型。团队里有人上来就说"Full FT 上 DDP"——理由是"16 张 80G = 1.28TB 显存,70B 模型才 140G,绰绰有余"。

实际跑起来后问题接踵而至:

DDP 启动后第一个 step 就 OOM:模型权重 140GB + Adam 状态 560GB + 梯度 140GB + 激活值 100GB+,单卡实际峰值 800GB;

切到 DeepSpeed ZeRO-3 后能跑,但通信开销爆表——8 卡 DDP 训练 1 step 只要 2.5s,ZeRO-3 跨机变成 18s;

加了 TP=2 后通信反而更多,每个 step 还要做两次 all-reduce;

最后发现模型并行和 ZeRO 是两套独立优化,需要按规模选,而不是叠加。

这就是为什么这一篇要把"数据并行 / 模型并行 / 流水线并行 / ZeRO"这些容易混淆的概念讲透——它们的工程取舍完全不同,错一个配置 GPU 就白买了。

一、分布式训练到底解决什么问题

训练大模型的三大硬约束:

硬约束 现象 数字举例
显存墙 单卡塞不下模型+优化器+梯度+激活 70B FP16 + AdamW = 840GB
算力墙 单卡训练 1 epoch 要 30 天 70B 模型在 1 张 A100 上要 50 天/1 epoch
带宽墙 多卡通信成为瓶颈 ZeRO-3 在 100Gbps 网上通信开销占比 60%+

分布式训练本质是用多卡的资源换单卡解决不了的问题,但每种并行策略解的"问题"不一样:

策略 解决的核心问题 适合规模
DP/DDP 算力墙 模型单卡装得下,训练慢
TP(张量并行) 显存墙(层内切分) 单层都装不下(70B+)
PP(流水线并行) 显存墙(层间切分) 模型层数多、batch=1 都装不下
ZeRO-1/2/3 显存墙(数据并行的优化版) 优化器状态/梯度分片到多卡
FSDP 显存墙(PyTorch 原生 ZeRO-3) 等价 ZeRO-3,但和 PyTorch 生态更紧
3D Parallelism 显存 + 算力 + 通信 70B+ 多机训练,工业级方案

 一句话记忆:DP 解决"训练慢",TP/PP/ZeRO 解决"塞不下",3D 组合是工业级标配。

二、适合什么场景 / 不适合什么场景

规模 推荐方案 理由
7B 模型 + 单卡 24G LoRA(无分布式) 没必要上分布式
7B 模型 + 8 卡 A100 DDP 数据并行最简单
13B 模型 + 8 卡 A100 DDP + ZeRO-2 优化器分片省显存
70B 模型 + 8 卡 A100 ZeRO-3 + TP=2 显存 + 通信平衡
70B 模型 + 32 卡 H100 3D 并行(TP=8, PP=2, DP=2) 工业级
175B 模型 + 128 卡 H100 3D 并行 + ZeRO-3 + 序列并行 极致场景
CPU 多机 DeepSpeed Offload 不推荐,IO 是瓶颈

经验公式:

模型参数量 ≤ 单卡显存 × 0.5 → DDP 足够

模型参数量 ≤ 单卡显存 × 2 → 加 ZeRO-2

模型参数量 ≤ 单卡显存 × 4 → 加 ZeRO-3

模型参数量 > 单卡显存 × 4 → 必须 TP/PP + ZeRO-3

 雷点:很多人把"ZeRO-3 = 强"当作铁律,但ZeRO-3 通信开销很大。如果 ZeRO-2 能装下,永远别上 ZeRO-3。

三、整体架构:五种并行策略的本质

 

┌──────────────────────────────────────────────────────────────────────┐
│ 1. Data Parallel (DP/DDP) - 数据并行                                  │
│                                                                       │
│   GPU 0: [Model] [data shard 0]  ─┐                                   │
│   GPU 1: [Model] [data shard 1]   ├→ all-reduce(gradient) 同步        │
│   GPU 2: [Model] [data shard 2]   │                                   │
│   GPU 3: [Model] [data shard 3]  ─┘                                   │
│                                                                       │
│   每张卡都有一份完整模型,数据切分到多卡                                │
│   显存: N × (模型+优化器+梯度)         通信: all-reduce 一次/步       │
└──────────────────────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────────────────────┐
│ 2. Tensor Parallel (TP) - 张量并行                                    │
│                                                                       │
│   QKV 线性层被切到多卡:                                                │
│   GPU 0: Q[0:n/2]   K[0:n/2]   V[0:n/2]   ─┐                         │
│   GPU 1: Q[n/2:n]  K[n/2:n]  V[n/2:n]     ├→ all-reduce (forward)    │
│   ...                                       │                         │
│   GPU 0: out_proj[0:n/2]                   │                         │
│   GPU 1: out_proj[n/2:n]                  ─┘                         │
│                                                                       │
│   单层都被切分,需要频繁同步                                             │
│   显存: 模型/N  通信: 每层 2 次 all-reduce                            │
└──────────────────────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────────────────────┐
│ 3. Pipeline Parallel (PP) - 流水线并行                                 │
│                                                                       │
│   模型的连续几层放到一张卡:                                            │
│   GPU 0: layers 0-7    ──┐                                            │
│   GPU 1: layers 8-15    ├──→ 微批次 pipeline 通信                     │
│   GPU 2: layers 16-23   │   (micro-batch 切换)                        │
│   GPU 3: layers 24-31  ─┘                                             │
│                                                                       │
│   显存: 模型/N  通信: 点对点 (相对少)  代价: bubble 浪费              │
└──────────────────────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────────────────────┐
│ 4. ZeRO (Zero Redundancy Optimizer) - 数据并行的优化版                 │
│                                                                       │
│   同样数据并行,但把优化器状态/梯度/参数分片到多卡                      │
│   Stage 1: 分片优化器状态 → 显存省 4x                                │
│   Stage 2: + 分片梯度              → 显存省 8x                        │
│   Stage 3: + 分片参数              → 显存省 Nx (N=卡数)               │
│                                                                       │
│   通信: ZeRO-1/2 增加 ~1.5x,ZeRO-3 增加 ~2-3x                         │
└──────────────────────────────────────────────────────────────────────┘

┌──────────────────────────────────────────────────────────────────────┐
│ 5. FSDP (Fully Sharded Data Parallel) - PyTorch 原生版                │
│                                                                       │
│   等价于 ZeRO-3,完全在 PyTorch 内部实现                                │
│   优势: 和 PyTorch 生态完全兼容;                                        │
│   劣势: 早期版本比 DeepSpeed 慢,2.x 后差距已不大                      │
└──────────────────────────────────────────────────────────────────────┘

 

四、核心流程拆解

4.1 显存计算公式(必背)

训练时的显存主要由四部分构成:

 

总显存 = 模型权重 + 优化器状态 + 梯度 + 激活值
        ↓          ↓            ↓       ↓
        P          2P (Adam)    2P     (batch × seq × hidden × L)
        (FP16)    (FP32 m,v)   (FP16)   (中间结果)

 

具体数字(7B 模型,FP16 训练):

组件 计算 显存
模型权重 7B × 2 bytes 14GB
优化器 (AdamW, FP32 m+v) 7B × 8 bytes 56GB
梯度 (FP16) 7B × 2 bytes 14GB
激活值 (L=32, batch=4, seq=2048) 视具体值 ~20GB
总计   ~104GB

这就是为什么 7B 模型 Full FT 需要 A100 80G × 2 卡(即使不考虑激活)。

4.2 DDP 训练流程

输入:模型 + 数据 + 多卡 动作:

每张卡加载完整模型

数据分片到各卡(DistributedSampler)

前向 + 反向:各卡独立算梯度

同步梯度:all-reduce(最关键的一步)

各卡独立更新参数

异常处理:

通信失败:检查 NCCL 版本、网络(InfiniBand vs Ethernet)

负载不均:检查 DistributedSampler 是否正确设了 shuffle=False 的验证集

收敛慢:检查 learning rate 是否需要随 batch_size 线性放大

4.3 ZeRO 配置决策

 

7B 模型 + 24G 单卡 → 上 ZeRO-2 吗?
├─ 不需要,LoRA 就够
└─ 强行上 ZeRO-2? 通信开销不划算

70B 模型 + 40G 单卡 × 8 → 上哪个?
├─ ZeRO-3: 70B × 2 / 8 = 17.5GB/卡,装得下
└─ ZeRO-3 + TP=2: 通信更省,但配置复杂

 

经验值:

模型规模 推荐
< 13B ZeRO-1 / ZeRO-2
13B-70B ZeRO-3 + TP=2
70B-175B 3D 并行

4.4 通信开销(容易被低估)

策略 通信量/步 受什么限制
DDP 2×模型大小 带宽
ZeRO-1 1.5× 带宽
ZeRO-2 1.5× 带宽
ZeRO-3 2-3× 带宽 + 延迟
TP 每层 2×all-reduce 延迟敏感
PP 点对点 (N-1)/N bubble

关键洞察:TP 受延迟影响大,ZeRO 受带宽影响大。InfiniBand 环境下 ZeRO-3 表现好,普通以太网下 TP 优于 ZeRO-3。

五、关键代码:四种方案的最小可跑示例

5.1 DDP(最基础,多机多卡)

启动:

 

torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 
  --master_addr=192.168.1.1 --master_port=29500 
  train_ddp.py

 

代码:

 

# train_ddp.py
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def main():
    # ===== 1. 初始化进程组 =====
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # ===== 2. 模型 + 数据 =====
    model = MyModel().cuda(local_rank)
    model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)
    # 混合精度
    scaler = torch.cuda.amp.GradScaler()

    dataset = MyDataset()
    sampler = DistributedSampler(dataset)
    loader = DataLoader(dataset, batch_size=4, sampler=sampler, num_workers=4)

    # ===== 3. 训练循环 =====
    for epoch in range(3):
        sampler.set_epoch(epoch)  # 关键:每 epoch 重设 seed
        for batch in loader:
            with torch.cuda.amp.autocast():
                loss = model(batch)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

 

DDP 关键参数:

find_unused_parameters=True:模型有未用子模块时(多任务/条件分支)必须开,代价是 10-20% 通信开销

gradient_as_bucket_view=True(PyTorch 2.0+ 默认):DDP 内部把梯度打包成桶做 all-reduce,比逐参数 all-reduce 快 30%

static_graph=True:模型结构不变时设 True,省 20% 通信

5.2 DeepSpeed ZeRO-3(70B 模型实战)

启动:

 

deepspeed --num_gpus=8 --num_nodes=2 train_zero3.py

 

ds_config_zero3.json:

 

{
  "bf16": {"enabled": true},
"optimizer": {
    "type": "AdamW",
    "params": {"lr": 1e-5, "betas": [0.9, 0.95], "weight_decay": 0.1}
},
"scheduler": {
    "type": "WarmupCosineLR",
    "params": {"warmup_min_ratio": 0.03, "warmup_num_steps": 200}
},
"zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "overlap_comm": true,// 通信和计算重叠,提速 10-20%
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto",
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": 1.0,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto"
}

 

train_zero3.py:

 

import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

# ===== 1. 加载模型 =====
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-72B-Instruct",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-72B-Instruct")

# ===== 2. DeepSpeed 初始化 =====
ds_engine = deepspeed.initialize(
    model=model,
    config="ds_config_zero3.json",
    model_parameters=model.parameters(),
)

# ===== 3. 训练循环 =====
for step, batch in enumerate(loader):
    loss = ds_engine.model(**batch).loss
    ds_engine.backward(loss)
    ds_engine.step()

 

ZeRO-3 关键参数解释:

stage: 3:完整 ZeRO-3(参数 + 梯度 + 优化器全分片)

offload_optimizer/pin_memory: true:把分片数据卸载到 CPU,省 GPU 显存 30%+,代价是慢 20%

overlap_comm: true:通信和计算重叠,提速关键

stage3_prefetch_bucket_size: "auto":预取参数,减少通信等待

stage3_max_live_parameters: 1e9:控制每张卡上同时驻留的参数上限,防显存爆掉

ZeRO 选型决策:

模型 / 卡 ZeRO Stage 理由
7B / 24G LoRA 不上 ZeRO 没必要
13B / 24G ZeRO-2 优化器分片省 4x
70B / 40G ZeRO-3 必须分片参数
70B / 80G ZeRO-2 + 混合精度 ZeRO-2 装得下就别上 ZeRO-3

5.3 FSDP(PyTorch 原生,2024 后推荐)

FSDP 在 PyTorch 2.0+ 已经成熟,和 DeepSpeed ZeRO-3 性能持平,生态更好。

 

# train_fsdp.py
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, BackwardPrefetch
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

def main():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-70b-hf",
        torch_dtype=torch.bfloat16,
    )

    # ===== 关键: auto wrap policy =====
    # 把每个 LlamaDecoderLayer 包成一个 FSDP 单元
    wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={LlamaDecoderLayer},
    )

    # ===== 混合精度策略 =====
    mp_policy = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )

    # ===== FSDP 包装 =====
    model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,  # 等价 ZeRO-3
        mixed_precision=mp_policy,
        auto_wrap_policy=wrap_policy,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        forward_prefetch=True,
        device_id=local_rank,
        limit_all_gathers=True,           # 限制同时 gather 数量,防显存爆
        use_orig_params=True,             # 必须,兼容 optimizer
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    for batch in loader:
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

 

FSDP 关键参数:

sharding_strategy:

FULL_SHARD(默认):等价 ZeRO-3

SHARD_GRAD_OP:等价 ZeRO-2

NO_SHARD:等价 DDP

HYBRID_SHARD:节点内全分片,跨节点复制(多机训练的好选择)

auto_wrap_policy:决定哪些层独立分片。transformer_auto_wrap_policy 把每个 Transformer 层当一个 FSDP 单元,比整模型分片通信少 50%。

backward_prefetch=BACKWARD_PRE:反向时预取下一层参数,提速 10-15%

forward_prefetch=True:前向时预取下一层

limit_all_gathers=True:关键! 限制同时 gather 的层数,防 OOM

use_orig_params=True:保持原始参数引用,让 optimizer 能正常用

FSDP vs DeepSpeed 对比:

维度 FSDP DeepSpeed ZeRO-3
性能 PyTorch 2.0+ 持平 持平
生态兼容 原生 PyTorch 需要 ds_engine
调试 容易(标准 PyTorch) 难(黑盒)
CPU offload 支持 支持
Activation checkpointing 原生 集成
文档/社区 越来越好 老牌稳

 一句话记忆:新项目优先 FSDP,老项目用 DeepSpeed。两者性能差距已经很小(< 5%)。

5.4 3D 并行(70B+ 工业级)

TP + PP + ZeRO-3 组合,用 Megatron-LM 或 DeepSpeed + accelerate 实现。

 

# 3D 并行需要修改模型结构
# 这里以 accelerate + 简单 TP 为例
# 实际工业级用 Megatron-DeepSpeed

# accelerate 配置: configs/deepspeed_zero3_tp2.yaml
compute_environment: LOCAL_MACHINE
distributed_type: DEEPSPEED
deepspeed_config:
  gradient_accumulation_steps: 16
  zero_optimization:
    stage: 3
    offload_optimizer: {device: cpu}
tensor_parallel:
  tp_size: 2# 张量并行度
pipeline_parallel:
  pp_size: 2# 流水线并行度
mixed_precision: bf16

 

3D 并行的实际通信(TP=2, PP=2, DP=4, 8 卡):

 

模型被切 4 份(2×2),数据被切 4 份
通信:
- TP: 每层 2 次 all-reduce(高频)
- PP: 4 次点对点(低频)
- ZeRO-3 DP: 4 卡间 all-gather/reduce-scatter(中频)

总通信开销: TP > ZeRO-3 > PP

 

3D 并行调参经验:

TP 优先 (TP=2-8):在节点内(NVLink 高速互联)

PP 次之 (PP=2-4):在节点间(带宽低)

DP/ZeRO 最后:剩下的卡

六、上线后如何评估训练效果

分布式训练出问题不容易排查,必须从一开始就有指标:

6.1 训练效率指标

指标 含义 健康值
Throughput tokens/秒/GPU 与硬件算力比例正相关
MFU Model FLOPs Utilization A100 上 40%-50%,H100 上 50%-60%
HTU Hardware FLOPs Utilization 同上
Step time 单步耗时 越稳定越好
通信占比 comm_time / step_time < 30% 为佳

测 throughput 代码:

 

# 训练 100 步后统计
start = time.time()
for step, batch in enumerate(loader):
    if step == 100:
        t = time.time() - start
        tokens_per_sec = 100 * batch_size * seq_len * num_gpus / t
        print(f"throughput: {tokens_per_sec:.0f} tokens/s")
        break

 

6.2 显存监控

 

# 在训练循环里加
print(f"GPU {rank} mem: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")

 

或者用 torch.cuda.memory_summary() 详细看:

allocated:当前分配

reserved:PyTorch 缓存

max_allocated:峰值

6.3 收敛性监控

指标 关注点
loss 曲线 是否和单卡 DDP 收敛一致
gradient norm 异常大/小都是问题
各卡 loss 同步性 DDP/PP 不同步是数据分片问题
eval 指标 跨 epoch 是否稳定提升

 雷点:分布式训练有时会"静默掉点"——loss 看起来正常但精度差 2-3%。必须做单卡 DDP baseline 对比。

七、常见坑与避雷(12 条实战总结)

坑 1:NCCL 超时

现象:RuntimeError: NCCL timeout。 根因:多机通信卡死或慢。 解决:

 

export NCCL_TIMEOUT=1800  # 30 分钟
export NCCL_IB_DISABLE=0  # 启用 InfiniBand
export NCCL_DEBUG=INFO   # 调试时打开,看哪个 rank 卡死

 

坑 2:CPU offload 把 IO 打爆

现象:ZeRO-3 + CPU offload 后训练极慢(每个 step 30s+)。 根因:CPU 内存带宽不够,PCIe 反复搬运数据。 解决:

去掉 offload,改用更多卡

用 NVMe SSD 做二级 offload(DeepSpeed 支持)

上 H100(96GB HBM3 减少 offload 需求)

坑 3:ZeRO-3 + 推理脚本不兼容

现象:ZeRO-3 训练完的 checkpoint 加载到推理脚本报错。 根因:ZeRO-3 训练时参数是分片的(每张卡只存 1/N),保存需要 gather。 解决:

 

# DeepSpeed 保存时
if ds_engine.global_rank == 0:
    ds_engine.save_checkpoint(save_dir)
# 加载时
ds_engine.load_checkpoint(load_dir)

 

坑 4:activation checkpointing 选错粒度

现象:开了 gradient checkpointing 后显存省了 50%,但速度慢 40%。 根因:粒度太细(每个线性层都 checkpoint),重算太多。 解决:

 

# 在 Transformers 模型里
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)
# 选 "每个 Transformer 层" 一个 checkpoint,粒度最合适

 

坑 5:TP/PP 切错位置

现象:TP=2 训练时 loss 震荡,TP=4 时直接发散。 根因:TP 切分必须按维度对齐(attention head 数必须能整除 TP 度)。 解决:

num_attention_heads % tp_size == 0

hidden_size % tp_size == 0

不满足时改用其他并行方式

坑 6:FSDP 的 use_orig_params 漏掉

现象:FSDP 训练时 optimizer.step() 报错。 根因:FSDP 默认会重写参数,optimizer 拿不到原始引用。 解决:

 

model = FSDP(model, use_orig_params=True, ...)  # 必须

 

坑 7:gradient_accumulation_steps 配错

现象:loss 不下降或下降极慢。 根因:等效 batch size 太大(ZeRO 下被分片放大),学习率没跟上。 解决:

 

# DDP 下 batch=32,ZeRO-3 下等效 batch 仍是 32,但梯度更新更稀疏
# 学习率建议调小(10% 起步)

 

坑 8:数据并行 + 不平衡数据集

现象:某些 epoch 特别慢(卡之间负载不均)。 根因:DistributedSampler 按样本数切分,长样本全在一张卡。 解决:

 

sampler = DistributedSampler(
    dataset,
    shuffle=True,
    seed=42,
    drop_last=True,
)
# 用 GroupByLengthSampler 按长度分组

 

坑 9:混合精度选错

现象:FP16 训练时 loss 突然变 NaN。 根因:FP16 数值范围小,gradient overflow 频繁。 解决:

优先用 BF16(A100+ 支持,范围和 FP32 一样)

FP16 必须配 GradScaler(前面 DDP 代码)

检查 loss.item() 是否出现 inf/nan

坑 10:多机训练卡在 init

现象:torchrun 启动后所有进程 hang 在初始化。 根因:网络不通、端口被防火墙挡、共享文件系统问题。 解决:

 

# 测试网络
nc -zv master_addr 29500
# 关闭防火墙(或开放端口)
# 用 Gloo 调试(CPU)
dist.init_process_group(backend="gloo")

 

坑 11:TP 度对不上 hidden size

现象:TP=8 训练时某些层报 "shape mismatch"。 根因:hidden_size=4096, num_heads=32, head_dim=128;TP=8 时 head_dim / TP = 16(OK),但 QKV 合并层要求 hidden_size 能整除 TP。 解决:

TP 度只能是 head 数的因数

7B 模型 head=32 → TP 只能 1/2/4/8/16/32

13B 模型 head=40 → TP 只能 1/2/4/5/8/10/20/40

坑 12:checkpoint 保存太慢

现象:ZeRO-3 保存 70B checkpoint 要 30 分钟+。 根因:每张卡只有 1/N 参数,保存时需要 gather 到 rank 0。 解决:

 

# 用 stage3_gather_16bit_weights_on_model_save: true
# 只在 rank 0 触发 gather
# 用 safetensors 单文件格式

 

 最后一条:分布式训练最大的隐性成本不是 GPU,而是跨机调试。永远从单卡 DDP 调通,再上分布式。

八、优化方向(从"能用"到"好用")

按性价比从高到低:

混合精度 (BF16):零成本提速 30%-50%

Flash Attention 2:提速 30%,省显存 20%

Gradient Checkpointing:省显存 30%-50%,慢 20%

FSDP/ZeRO-2:把多卡的"数据并行"变成"数据并行 + 优化器分片"

Optimizer CPU Offload:省显存,慢 20%

Fused Optimizer(DeepSpeed):AdamW 改成 FusedAdam,省显存 30%

Communication Overlap:通信和计算重叠,提速 10%-20%

Sequence Parallelism:超长序列(> 8k)必备

ZeRO-Infinity / ZeRO++:极致优化版

3D Parallelism:终极方案

单卡训练 → DDP → ZeRO-2 → ZeRO-3 → TP+PP+ZeRO,按规模逐步升级。

何时该考虑别的方案:

单卡就能跑 → 别分布式,复杂度不值

8 卡 DDP 还慢 → 数据并行上限到了,上 TP/PP

模型 < 1B → 多半不需要分布式

推理性能问题 → 量化 + 蒸馏,不要用训练手段解决

九、收尾

分布式训练的核心是按规模选策略:

模型规模 推荐
< 7B 单卡 DDP 或 LoRA(前面文章讲过)
7B-13B DDP + ZeRO-2
13B-70B FSDP / ZeRO-3 + TP=2-4
70B+ 3D 并行(TP + PP + ZeRO)

三个最容易犯的错:

小模型用 ZeRO-3:通信开销大于收益

大模型用 DDP:直接 OOM

不调 baseline 就上分布式:90% 的"分布式训练"问题都是单卡调通后才能复现的

决策流程:

单卡 DDP 跑通 + 收敛 → 验证算法

加 ZeRO-2 / FSDP → 验证显存

加 TP/PP → 验证规模

多机多卡 → 验证网络

生产稳定 → 监控 + 弹性调度

什么时候该考虑别的方案?当分布式训练本身成了瓶颈——比如成本太高、训练太慢、调试太难——这时候该回头看是不是模型选太大了:能不能用更小的 base model + LoRA 微调(前面文章讲过)?能不能用蒸馏让大模型教小模型?大模型工程不是堆算力,而是"用最小的资源达到业务效果"——这一原则贯穿了前面三篇(微调、量化、并行),也是所有 AI 工程师最终要回到的朴素出发点。

 

 

打开APP阅读更多精彩内容
声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

全部0条评论

快来发表一下你的评论吧 !

×
20
完善资料,
赚取积分