模型并行与数据并行:分布式训练策略与 ZeRO 优化
适用读者:要在多卡/多机上训练 7B-70B 大模型的工程师。本文不讲 NCCL 源码,只讲:每种并行策略(DP/DDP/TP/PP/ZeRO/FSDP)适合什么规模、显存怎么算、通信开销多大、配置怎么写、坑在哪里。读完你应该能自己拍板"我的 70B 模型在 8 卡 A100 上该用哪种并行"。
引子:那个 70B 模型"理论上能跑"的悲剧
去年我们接了一个项目,要在 16 块 A100 80G 上微调一个 70B 模型。团队里有人上来就说"Full FT 上 DDP"——理由是"16 张 80G = 1.28TB 显存,70B 模型才 140G,绰绰有余"。
实际跑起来后问题接踵而至:
DDP 启动后第一个 step 就 OOM:模型权重 140GB + Adam 状态 560GB + 梯度 140GB + 激活值 100GB+,单卡实际峰值 800GB;
切到 DeepSpeed ZeRO-3 后能跑,但通信开销爆表——8 卡 DDP 训练 1 step 只要 2.5s,ZeRO-3 跨机变成 18s;
加了 TP=2 后通信反而更多,每个 step 还要做两次 all-reduce;
最后发现模型并行和 ZeRO 是两套独立优化,需要按规模选,而不是叠加。
这就是为什么这一篇要把"数据并行 / 模型并行 / 流水线并行 / ZeRO"这些容易混淆的概念讲透——它们的工程取舍完全不同,错一个配置 GPU 就白买了。
一、分布式训练到底解决什么问题
训练大模型的三大硬约束:
| 硬约束 | 现象 | 数字举例 |
|---|---|---|
| 显存墙 | 单卡塞不下模型+优化器+梯度+激活 | 70B FP16 + AdamW = 840GB |
| 算力墙 | 单卡训练 1 epoch 要 30 天 | 70B 模型在 1 张 A100 上要 50 天/1 epoch |
| 带宽墙 | 多卡通信成为瓶颈 | ZeRO-3 在 100Gbps 网上通信开销占比 60%+ |
分布式训练本质是用多卡的资源换单卡解决不了的问题,但每种并行策略解的"问题"不一样:
| 策略 | 解决的核心问题 | 适合规模 |
|---|---|---|
| DP/DDP | 算力墙 | 模型单卡装得下,训练慢 |
| TP(张量并行) | 显存墙(层内切分) | 单层都装不下(70B+) |
| PP(流水线并行) | 显存墙(层间切分) | 模型层数多、batch=1 都装不下 |
| ZeRO-1/2/3 | 显存墙(数据并行的优化版) | 优化器状态/梯度分片到多卡 |
| FSDP | 显存墙(PyTorch 原生 ZeRO-3) | 等价 ZeRO-3,但和 PyTorch 生态更紧 |
| 3D Parallelism | 显存 + 算力 + 通信 | 70B+ 多机训练,工业级方案 |
一句话记忆:DP 解决"训练慢",TP/PP/ZeRO 解决"塞不下",3D 组合是工业级标配。
二、适合什么场景 / 不适合什么场景
| 规模 | 推荐方案 | 理由 |
|---|---|---|
| 7B 模型 + 单卡 24G | LoRA(无分布式) | 没必要上分布式 |
| 7B 模型 + 8 卡 A100 | DDP | 数据并行最简单 |
| 13B 模型 + 8 卡 A100 | DDP + ZeRO-2 | 优化器分片省显存 |
| 70B 模型 + 8 卡 A100 | ZeRO-3 + TP=2 | 显存 + 通信平衡 |
| 70B 模型 + 32 卡 H100 | 3D 并行(TP=8, PP=2, DP=2) | 工业级 |
| 175B 模型 + 128 卡 H100 | 3D 并行 + ZeRO-3 + 序列并行 | 极致场景 |
| CPU 多机 | DeepSpeed Offload | 不推荐,IO 是瓶颈 |
经验公式:
模型参数量 ≤ 单卡显存 × 0.5 → DDP 足够
模型参数量 ≤ 单卡显存 × 2 → 加 ZeRO-2
模型参数量 ≤ 单卡显存 × 4 → 加 ZeRO-3
模型参数量 > 单卡显存 × 4 → 必须 TP/PP + ZeRO-3
雷点:很多人把"ZeRO-3 = 强"当作铁律,但ZeRO-3 通信开销很大。如果 ZeRO-2 能装下,永远别上 ZeRO-3。
三、整体架构:五种并行策略的本质
┌──────────────────────────────────────────────────────────────────────┐ │ 1. Data Parallel (DP/DDP) - 数据并行 │ │ │ │ GPU 0: [Model] [data shard 0] ─┐ │ │ GPU 1: [Model] [data shard 1] ├→ all-reduce(gradient) 同步 │ │ GPU 2: [Model] [data shard 2] │ │ │ GPU 3: [Model] [data shard 3] ─┘ │ │ │ │ 每张卡都有一份完整模型,数据切分到多卡 │ │ 显存: N × (模型+优化器+梯度) 通信: all-reduce 一次/步 │ └──────────────────────────────────────────────────────────────────────┘ ┌──────────────────────────────────────────────────────────────────────┐ │ 2. Tensor Parallel (TP) - 张量并行 │ │ │ │ QKV 线性层被切到多卡: │ │ GPU 0: Q[0:n/2] K[0:n/2] V[0:n/2] ─┐ │ │ GPU 1: Q[n/2:n] K[n/2:n] V[n/2:n] ├→ all-reduce (forward) │ │ ... │ │ │ GPU 0: out_proj[0:n/2] │ │ │ GPU 1: out_proj[n/2:n] ─┘ │ │ │ │ 单层都被切分,需要频繁同步 │ │ 显存: 模型/N 通信: 每层 2 次 all-reduce │ └──────────────────────────────────────────────────────────────────────┘ ┌──────────────────────────────────────────────────────────────────────┐ │ 3. Pipeline Parallel (PP) - 流水线并行 │ │ │ │ 模型的连续几层放到一张卡: │ │ GPU 0: layers 0-7 ──┐ │ │ GPU 1: layers 8-15 ├──→ 微批次 pipeline 通信 │ │ GPU 2: layers 16-23 │ (micro-batch 切换) │ │ GPU 3: layers 24-31 ─┘ │ │ │ │ 显存: 模型/N 通信: 点对点 (相对少) 代价: bubble 浪费 │ └──────────────────────────────────────────────────────────────────────┘ ┌──────────────────────────────────────────────────────────────────────┐ │ 4. ZeRO (Zero Redundancy Optimizer) - 数据并行的优化版 │ │ │ │ 同样数据并行,但把优化器状态/梯度/参数分片到多卡 │ │ Stage 1: 分片优化器状态 → 显存省 4x │ │ Stage 2: + 分片梯度 → 显存省 8x │ │ Stage 3: + 分片参数 → 显存省 Nx (N=卡数) │ │ │ │ 通信: ZeRO-1/2 增加 ~1.5x,ZeRO-3 增加 ~2-3x │ └──────────────────────────────────────────────────────────────────────┘ ┌──────────────────────────────────────────────────────────────────────┐ │ 5. FSDP (Fully Sharded Data Parallel) - PyTorch 原生版 │ │ │ │ 等价于 ZeRO-3,完全在 PyTorch 内部实现 │ │ 优势: 和 PyTorch 生态完全兼容; │ │ 劣势: 早期版本比 DeepSpeed 慢,2.x 后差距已不大 │ └──────────────────────────────────────────────────────────────────────┘
四、核心流程拆解
4.1 显存计算公式(必背)
训练时的显存主要由四部分构成:
总显存 = 模型权重 + 优化器状态 + 梯度 + 激活值 ↓ ↓ ↓ ↓ P 2P (Adam) 2P (batch × seq × hidden × L) (FP16) (FP32 m,v) (FP16) (中间结果)
具体数字(7B 模型,FP16 训练):
| 组件 | 计算 | 显存 |
|---|---|---|
| 模型权重 | 7B × 2 bytes | 14GB |
| 优化器 (AdamW, FP32 m+v) | 7B × 8 bytes | 56GB |
| 梯度 (FP16) | 7B × 2 bytes | 14GB |
| 激活值 (L=32, batch=4, seq=2048) | 视具体值 | ~20GB |
| 总计 | ~104GB |
这就是为什么 7B 模型 Full FT 需要 A100 80G × 2 卡(即使不考虑激活)。
4.2 DDP 训练流程
输入:模型 + 数据 + 多卡 动作:
每张卡加载完整模型
数据分片到各卡(DistributedSampler)
前向 + 反向:各卡独立算梯度
同步梯度:all-reduce(最关键的一步)
各卡独立更新参数
异常处理:
通信失败:检查 NCCL 版本、网络(InfiniBand vs Ethernet)
负载不均:检查 DistributedSampler 是否正确设了 shuffle=False 的验证集
收敛慢:检查 learning rate 是否需要随 batch_size 线性放大
4.3 ZeRO 配置决策
7B 模型 + 24G 单卡 → 上 ZeRO-2 吗? ├─ 不需要,LoRA 就够 └─ 强行上 ZeRO-2? 通信开销不划算 70B 模型 + 40G 单卡 × 8 → 上哪个? ├─ ZeRO-3: 70B × 2 / 8 = 17.5GB/卡,装得下 └─ ZeRO-3 + TP=2: 通信更省,但配置复杂
经验值:
| 模型规模 | 推荐 |
|---|---|
| < 13B | ZeRO-1 / ZeRO-2 |
| 13B-70B | ZeRO-3 + TP=2 |
| 70B-175B | 3D 并行 |
4.4 通信开销(容易被低估)
| 策略 | 通信量/步 | 受什么限制 |
|---|---|---|
| DDP | 2×模型大小 | 带宽 |
| ZeRO-1 | 1.5× | 带宽 |
| ZeRO-2 | 1.5× | 带宽 |
| ZeRO-3 | 2-3× | 带宽 + 延迟 |
| TP | 每层 2×all-reduce | 延迟敏感 |
| PP | 点对点 (N-1)/N | bubble |
关键洞察:TP 受延迟影响大,ZeRO 受带宽影响大。InfiniBand 环境下 ZeRO-3 表现好,普通以太网下 TP 优于 ZeRO-3。
五、关键代码:四种方案的最小可跑示例
5.1 DDP(最基础,多机多卡)
启动:
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=192.168.1.1 --master_port=29500 train_ddp.py
代码:
# train_ddp.py import torch import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler def main(): # ===== 1. 初始化进程组 ===== dist.init_process_group(backend="nccl") local_rank = int(os.environ["LOCAL_RANK"]) torch.cuda.set_device(local_rank) # ===== 2. 模型 + 数据 ===== model = MyModel().cuda(local_rank) model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) # 混合精度 scaler = torch.cuda.amp.GradScaler() dataset = MyDataset() sampler = DistributedSampler(dataset) loader = DataLoader(dataset, batch_size=4, sampler=sampler, num_workers=4) # ===== 3. 训练循环 ===== for epoch in range(3): sampler.set_epoch(epoch) # 关键:每 epoch 重设 seed for batch in loader: with torch.cuda.amp.autocast(): loss = model(batch) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() optimizer.zero_grad()
DDP 关键参数:
find_unused_parameters=True:模型有未用子模块时(多任务/条件分支)必须开,代价是 10-20% 通信开销
gradient_as_bucket_view=True(PyTorch 2.0+ 默认):DDP 内部把梯度打包成桶做 all-reduce,比逐参数 all-reduce 快 30%
static_graph=True:模型结构不变时设 True,省 20% 通信
5.2 DeepSpeed ZeRO-3(70B 模型实战)
启动:
deepspeed --num_gpus=8 --num_nodes=2 train_zero3.py
ds_config_zero3.json:
{
"bf16": {"enabled": true},
"optimizer": {
"type": "AdamW",
"params": {"lr": 1e-5, "betas": [0.9, 0.95], "weight_decay": 0.1}
},
"scheduler": {
"type": "WarmupCosineLR",
"params": {"warmup_min_ratio": 0.03, "warmup_num_steps": 200}
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,// 通信和计算重叠,提速 10-20%
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": 1.0,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto"
}
train_zero3.py:
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer
# ===== 1. 加载模型 =====
model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen2-72B-Instruct",
torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-72B-Instruct")
# ===== 2. DeepSpeed 初始化 =====
ds_engine = deepspeed.initialize(
model=model,
config="ds_config_zero3.json",
model_parameters=model.parameters(),
)
# ===== 3. 训练循环 =====
for step, batch in enumerate(loader):
loss = ds_engine.model(**batch).loss
ds_engine.backward(loss)
ds_engine.step()
ZeRO-3 关键参数解释:
stage: 3:完整 ZeRO-3(参数 + 梯度 + 优化器全分片)
offload_optimizer/pin_memory: true:把分片数据卸载到 CPU,省 GPU 显存 30%+,代价是慢 20%
overlap_comm: true:通信和计算重叠,提速关键
stage3_prefetch_bucket_size: "auto":预取参数,减少通信等待
stage3_max_live_parameters: 1e9:控制每张卡上同时驻留的参数上限,防显存爆掉
ZeRO 选型决策:
| 模型 / 卡 | ZeRO Stage | 理由 |
|---|---|---|
| 7B / 24G | LoRA 不上 ZeRO | 没必要 |
| 13B / 24G | ZeRO-2 | 优化器分片省 4x |
| 70B / 40G | ZeRO-3 | 必须分片参数 |
| 70B / 80G | ZeRO-2 + 混合精度 | ZeRO-2 装得下就别上 ZeRO-3 |
5.3 FSDP(PyTorch 原生,2024 后推荐)
FSDP 在 PyTorch 2.0+ 已经成熟,和 DeepSpeed ZeRO-3 性能持平,生态更好。
# train_fsdp.py
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, BackwardPrefetch
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
def main():
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-70b-hf",
torch_dtype=torch.bfloat16,
)
# ===== 关键: auto wrap policy =====
# 把每个 LlamaDecoderLayer 包成一个 FSDP 单元
wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={LlamaDecoderLayer},
)
# ===== 混合精度策略 =====
mp_policy = MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
)
# ===== FSDP 包装 =====
model = FSDP(
model,
sharding_strategy=ShardingStrategy.FULL_SHARD, # 等价 ZeRO-3
mixed_precision=mp_policy,
auto_wrap_policy=wrap_policy,
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
forward_prefetch=True,
device_id=local_rank,
limit_all_gathers=True, # 限制同时 gather 数量,防显存爆
use_orig_params=True, # 必须,兼容 optimizer
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
for batch in loader:
loss = model(**batch).loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
FSDP 关键参数:
sharding_strategy:
FULL_SHARD(默认):等价 ZeRO-3
SHARD_GRAD_OP:等价 ZeRO-2
NO_SHARD:等价 DDP
HYBRID_SHARD:节点内全分片,跨节点复制(多机训练的好选择)
auto_wrap_policy:决定哪些层独立分片。transformer_auto_wrap_policy 把每个 Transformer 层当一个 FSDP 单元,比整模型分片通信少 50%。
backward_prefetch=BACKWARD_PRE:反向时预取下一层参数,提速 10-15%
forward_prefetch=True:前向时预取下一层
limit_all_gathers=True:关键! 限制同时 gather 的层数,防 OOM
use_orig_params=True:保持原始参数引用,让 optimizer 能正常用
FSDP vs DeepSpeed 对比:
| 维度 | FSDP | DeepSpeed ZeRO-3 |
|---|---|---|
| 性能 | PyTorch 2.0+ 持平 | 持平 |
| 生态兼容 | 原生 PyTorch | 需要 ds_engine |
| 调试 | 容易(标准 PyTorch) | 难(黑盒) |
| CPU offload | 支持 | 支持 |
| Activation checkpointing | 原生 | 集成 |
| 文档/社区 | 越来越好 | 老牌稳 |
一句话记忆:新项目优先 FSDP,老项目用 DeepSpeed。两者性能差距已经很小(< 5%)。
5.4 3D 并行(70B+ 工业级)
TP + PP + ZeRO-3 组合,用 Megatron-LM 或 DeepSpeed + accelerate 实现。
# 3D 并行需要修改模型结构
# 这里以 accelerate + 简单 TP 为例
# 实际工业级用 Megatron-DeepSpeed
# accelerate 配置: configs/deepspeed_zero3_tp2.yaml
compute_environment: LOCAL_MACHINE
distributed_type: DEEPSPEED
deepspeed_config:
gradient_accumulation_steps: 16
zero_optimization:
stage: 3
offload_optimizer: {device: cpu}
tensor_parallel:
tp_size: 2# 张量并行度
pipeline_parallel:
pp_size: 2# 流水线并行度
mixed_precision: bf16
3D 并行的实际通信(TP=2, PP=2, DP=4, 8 卡):
模型被切 4 份(2×2),数据被切 4 份 通信: - TP: 每层 2 次 all-reduce(高频) - PP: 4 次点对点(低频) - ZeRO-3 DP: 4 卡间 all-gather/reduce-scatter(中频) 总通信开销: TP > ZeRO-3 > PP
3D 并行调参经验:
TP 优先 (TP=2-8):在节点内(NVLink 高速互联)
PP 次之 (PP=2-4):在节点间(带宽低)
DP/ZeRO 最后:剩下的卡
六、上线后如何评估训练效果
分布式训练出问题不容易排查,必须从一开始就有指标:
6.1 训练效率指标
| 指标 | 含义 | 健康值 |
|---|---|---|
| Throughput | tokens/秒/GPU | 与硬件算力比例正相关 |
| MFU | Model FLOPs Utilization | A100 上 40%-50%,H100 上 50%-60% |
| HTU | Hardware FLOPs Utilization | 同上 |
| Step time | 单步耗时 | 越稳定越好 |
| 通信占比 | comm_time / step_time | < 30% 为佳 |
测 throughput 代码:
# 训练 100 步后统计
start = time.time()
for step, batch in enumerate(loader):
if step == 100:
t = time.time() - start
tokens_per_sec = 100 * batch_size * seq_len * num_gpus / t
print(f"throughput: {tokens_per_sec:.0f} tokens/s")
break
6.2 显存监控
# 在训练循环里加
print(f"GPU {rank} mem: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")
或者用 torch.cuda.memory_summary() 详细看:
allocated:当前分配
reserved:PyTorch 缓存
max_allocated:峰值
6.3 收敛性监控
| 指标 | 关注点 |
|---|---|
| loss 曲线 | 是否和单卡 DDP 收敛一致 |
| gradient norm | 异常大/小都是问题 |
| 各卡 loss 同步性 | DDP/PP 不同步是数据分片问题 |
| eval 指标 | 跨 epoch 是否稳定提升 |
雷点:分布式训练有时会"静默掉点"——loss 看起来正常但精度差 2-3%。必须做单卡 DDP baseline 对比。
七、常见坑与避雷(12 条实战总结)
坑 1:NCCL 超时
现象:RuntimeError: NCCL timeout。 根因:多机通信卡死或慢。 解决:
export NCCL_TIMEOUT=1800 # 30 分钟 export NCCL_IB_DISABLE=0 # 启用 InfiniBand export NCCL_DEBUG=INFO # 调试时打开,看哪个 rank 卡死
坑 2:CPU offload 把 IO 打爆
现象:ZeRO-3 + CPU offload 后训练极慢(每个 step 30s+)。 根因:CPU 内存带宽不够,PCIe 反复搬运数据。 解决:
去掉 offload,改用更多卡
用 NVMe SSD 做二级 offload(DeepSpeed 支持)
上 H100(96GB HBM3 减少 offload 需求)
坑 3:ZeRO-3 + 推理脚本不兼容
现象:ZeRO-3 训练完的 checkpoint 加载到推理脚本报错。 根因:ZeRO-3 训练时参数是分片的(每张卡只存 1/N),保存需要 gather。 解决:
# DeepSpeed 保存时 if ds_engine.global_rank == 0: ds_engine.save_checkpoint(save_dir) # 加载时 ds_engine.load_checkpoint(load_dir)
坑 4:activation checkpointing 选错粒度
现象:开了 gradient checkpointing 后显存省了 50%,但速度慢 40%。 根因:粒度太细(每个线性层都 checkpoint),重算太多。 解决:
# 在 Transformers 模型里
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
# 选 "每个 Transformer 层" 一个 checkpoint,粒度最合适
坑 5:TP/PP 切错位置
现象:TP=2 训练时 loss 震荡,TP=4 时直接发散。 根因:TP 切分必须按维度对齐(attention head 数必须能整除 TP 度)。 解决:
num_attention_heads % tp_size == 0
hidden_size % tp_size == 0
不满足时改用其他并行方式
坑 6:FSDP 的 use_orig_params 漏掉
现象:FSDP 训练时 optimizer.step() 报错。 根因:FSDP 默认会重写参数,optimizer 拿不到原始引用。 解决:
model = FSDP(model, use_orig_params=True, ...) # 必须
坑 7:gradient_accumulation_steps 配错
现象:loss 不下降或下降极慢。 根因:等效 batch size 太大(ZeRO 下被分片放大),学习率没跟上。 解决:
# DDP 下 batch=32,ZeRO-3 下等效 batch 仍是 32,但梯度更新更稀疏 # 学习率建议调小(10% 起步)
坑 8:数据并行 + 不平衡数据集
现象:某些 epoch 特别慢(卡之间负载不均)。 根因:DistributedSampler 按样本数切分,长样本全在一张卡。 解决:
sampler = DistributedSampler( dataset, shuffle=True, seed=42, drop_last=True, ) # 用 GroupByLengthSampler 按长度分组
坑 9:混合精度选错
现象:FP16 训练时 loss 突然变 NaN。 根因:FP16 数值范围小,gradient overflow 频繁。 解决:
优先用 BF16(A100+ 支持,范围和 FP32 一样)
FP16 必须配 GradScaler(前面 DDP 代码)
检查 loss.item() 是否出现 inf/nan
坑 10:多机训练卡在 init
现象:torchrun 启动后所有进程 hang 在初始化。 根因:网络不通、端口被防火墙挡、共享文件系统问题。 解决:
# 测试网络 nc -zv master_addr 29500 # 关闭防火墙(或开放端口) # 用 Gloo 调试(CPU) dist.init_process_group(backend="gloo")
坑 11:TP 度对不上 hidden size
现象:TP=8 训练时某些层报 "shape mismatch"。 根因:hidden_size=4096, num_heads=32, head_dim=128;TP=8 时 head_dim / TP = 16(OK),但 QKV 合并层要求 hidden_size 能整除 TP。 解决:
TP 度只能是 head 数的因数
7B 模型 head=32 → TP 只能 1/2/4/8/16/32
13B 模型 head=40 → TP 只能 1/2/4/5/8/10/20/40
坑 12:checkpoint 保存太慢
现象:ZeRO-3 保存 70B checkpoint 要 30 分钟+。 根因:每张卡只有 1/N 参数,保存时需要 gather 到 rank 0。 解决:
# 用 stage3_gather_16bit_weights_on_model_save: true # 只在 rank 0 触发 gather # 用 safetensors 单文件格式
最后一条:分布式训练最大的隐性成本不是 GPU,而是跨机调试。永远从单卡 DDP 调通,再上分布式。
八、优化方向(从"能用"到"好用")
按性价比从高到低:
混合精度 (BF16):零成本提速 30%-50%
Flash Attention 2:提速 30%,省显存 20%
Gradient Checkpointing:省显存 30%-50%,慢 20%
FSDP/ZeRO-2:把多卡的"数据并行"变成"数据并行 + 优化器分片"
Optimizer CPU Offload:省显存,慢 20%
Fused Optimizer(DeepSpeed):AdamW 改成 FusedAdam,省显存 30%
Communication Overlap:通信和计算重叠,提速 10%-20%
Sequence Parallelism:超长序列(> 8k)必备
ZeRO-Infinity / ZeRO++:极致优化版
3D Parallelism:终极方案
单卡训练 → DDP → ZeRO-2 → ZeRO-3 → TP+PP+ZeRO,按规模逐步升级。
何时该考虑别的方案:
单卡就能跑 → 别分布式,复杂度不值
8 卡 DDP 还慢 → 数据并行上限到了,上 TP/PP
模型 < 1B → 多半不需要分布式
推理性能问题 → 量化 + 蒸馏,不要用训练手段解决
九、收尾
分布式训练的核心是按规模选策略:
| 模型规模 | 推荐 |
|---|---|
| < 7B | 单卡 DDP 或 LoRA(前面文章讲过) |
| 7B-13B | DDP + ZeRO-2 |
| 13B-70B | FSDP / ZeRO-3 + TP=2-4 |
| 70B+ | 3D 并行(TP + PP + ZeRO) |
三个最容易犯的错:
小模型用 ZeRO-3:通信开销大于收益
大模型用 DDP:直接 OOM
不调 baseline 就上分布式:90% 的"分布式训练"问题都是单卡调通后才能复现的
决策流程:
单卡 DDP 跑通 + 收敛 → 验证算法
加 ZeRO-2 / FSDP → 验证显存
加 TP/PP → 验证规模
多机多卡 → 验证网络
生产稳定 → 监控 + 弹性调度
什么时候该考虑别的方案?当分布式训练本身成了瓶颈——比如成本太高、训练太慢、调试太难——这时候该回头看是不是模型选太大了:能不能用更小的 base model + LoRA 微调(前面文章讲过)?能不能用蒸馏让大模型教小模型?大模型工程不是堆算力,而是"用最小的资源达到业务效果"——这一原则贯穿了前面三篇(微调、量化、并行),也是所有 AI 工程师最终要回到的朴素出发点。
全部0条评论
快来发表一下你的评论吧 !