PyTorch DDP 随机卡死复盘
PyTorch DDP 随机卡死复盘:最后一个 batch 挂起,NCCL 等待不返回,三步修复 Sampler & drop_last
一次真实的分布式训练“玄学卡死”:2 卡训练偶发在 epoch 尾部停住不动,GPU 利用率掉到 0%,日志无异常。最终定位是 DistributedSampler 使用不当 + drop_last=False + 忘记 set_epoch 引发各 rank 步数不一致,导致 allreduce 永久等待。
技术环境
OS:Ubuntu 22.04
Python:3.10.13
PyTorch:2.2.2 + CUDA 12.1(torch==2.2.2+cu121)
NCCL:2.18(系统自带,未自编译)
GPU:2×RTX 3090(24GB)
启动方式:torchrun --standalone --nproc_per_node=2 train.py
Bug 现象
训练随机在某些 epoch 尾部卡住,无异常栈;nvidia-smi 显示两卡功耗接近空闲。
偶尔能看到 NCCL 打印(并不总出现):
NCCL WARN Reduce failed: … Async operation timed out
kill -SIGQUIT 打印 Python 栈后发现停在 反向传播的梯度 allreduce 上(DistributedDataParallel 内部)。
关掉 DDP(单卡训练)完全正常;把 batch_size 改小/大,卡住概率改变但仍会发生。
最小可复现(错误版)
问题点集中在 数据划分不均 + Sampler 误用:
shuffle=True 与 DistributedSampler 混用(会被忽略但容易误导)。
drop_last=False 时,最后一个小批的样本数在不同 rank 上可能不一致(当 len(dataset) 不是 world_size 的整数倍且某些数据被过滤/增强丢弃时尤其明显)。
每个 epoch 忘记调用 sampler.set_epoch(epoch),导致各 rank 的随机顺序不一致。
train_ddp_wrong.py —— 错误示例(请勿照抄)
import os, random, torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler
class DummyDS(Dataset):
def init(self, N=1003): # 刻意设成非 world_size 整数倍
self.N = N
def len(self): return self.N
def getitem(self, i):
x = torch.randn(32, 3, 224, 224)
y = torch.randint(0, 10, (32,)) # 模拟有时会丢弃某些样本的增强(省略)
return x, y
def setup():
dist.init_process_group(“nccl”)
torch.cuda.set_device(int(os.environ[“LOCAL_RANK”]))
def main():
setup()
rank = dist.get_rank()
device = torch.device(“cuda”, int(os.environ[“LOCAL_RANK”]))
ds = DummyDS()
sampler = DistributedSampler(ds, shuffle=True, drop_last=False) # ❌ drop_last=False
# ❌ DataLoader 里又写了 shuffle=True(被忽略,但容易误以为生效)
loader = DataLoader(ds, batch_size=2, shuffle=True, sampler=sampler, num_workers=4)model = torch.nn.Linear(3*224*224, 10).to(device)
model = DDP(model, device_ids=[device.index])
opt = torch.optim.SGD(model.parameters(), lr=0.1)for epoch in range(5):# ❌ 忘记 sampler.set_epoch(epoch)for x, y in loader:x = x.view(x.size(0), -1).to(device)y = y.to(device)opt.zero_grad()loss = torch.nn.functional.cross_entropy(model(x), y)loss.backward() # 🔥 偶发卡在这里(allreduce)opt.step()if rank == 0:print(f"epoch {epoch} done")dist.destroy_process_group()
if name == “main”:
main()
触发条件(满足一两个就可能复现):
len(dataset) 不是 world_size 的整数倍。
动态数据过滤/增强(例如有时返回 None 或丢样),导致各 rank 实际步数不同。
忘记 sampler.set_epoch(epoch),各 rank 洗牌次序不同。
drop_last=False,导致最后一个 batch 在各 rank 的样本数不同。
某些自定义 collate_fn 在“空 batch”时直接 continue。
排查步骤
1)先确认“各 rank 步数一致”
在训练 loop 里加统计(不要只在 rank0 打印):
from collections import Counter
steps = Counter()
for i, _ in enumerate(loader):
steps[rank] += 1
dist.all_reduce(torch.tensor([steps[rank]], device=device), op=dist.ReduceOp.SUM)
或每个 rank 各自 print,检查是否相等
我的现象:有的 epoch,rank0 比 rank1 多 1–2 个 step。
2)开启 NCCL 调试
在启动前设置:
export NCCL_DEBUG=INFO
export NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_BLOCKING_WAIT=1
再跑一遍,可看到某些 allreduce 一直等不到某 rank 进来。
3)检查 Sampler 与 DataLoader 参数
DistributedSampler 必须搭配 sampler.set_epoch(epoch)。
DataLoader 里不要再写 shuffle=True。
若数据不可整除,优先 drop_last=True;否则确保各 rank 最后一个 batch 大小一致(例如补齐/填充)。
解决方案(修复版)
✅ 方案 A:严格对齐 Sampler 语义 + 丢最后不齐整的 batch
train_ddp_fixed.py —— 推荐修复
import os, torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, Dataset
class DummyDS(Dataset):
def init(self, N=1003): self.N=N
def len(self): return self.N
def getitem(self, i):
x = torch.randn(32, 3, 224, 224)
y = torch.randint(0, 10, (32,))
return x, y
def setup():
dist.init_process_group(“nccl”)
torch.cuda.set_device(int(os.environ[“LOCAL_RANK”]))
def main():
setup()
rank = dist.get_rank()
device = torch.device(“cuda”, int(os.environ[“LOCAL_RANK”]))
ds = DummyDS()
# 关键 1:使用 DistributedSampler,统一交给它洗牌
sampler = DistributedSampler(ds, shuffle=True, drop_last=True) # ✅
# 关键 2:DataLoader 里不要再写 shuffle
loader = DataLoader(ds, batch_size=2, sampler=sampler, num_workers=4, pin_memory=True)model = torch.nn.Linear(3*224*224, 10).to(device)
ddp = DDP(model, device_ids=[device.index], find_unused_parameters=False) # 如无动态分支,关掉更稳更快
opt = torch.optim.SGD(ddp.parameters(), lr=0.1)for epoch in range(5):sampler.set_epoch(epoch) # ✅ 关键 3:每个 epoch 设置不同随机种子for x, y in loader:x = x.view(x.size(0), -1).to(device, non_blocking=True)y = y.to(device, non_blocking=True)opt.zero_grad(set_to_none=True)loss = torch.nn.functional.cross_entropy(ddp(x), y)loss.backward()opt.step()if rank == 0:print(f"epoch {epoch} ok")dist.barrier() # ✅ 收尾同步,避免 rank 提前退出
dist.destroy_process_group()
if name == “main”:
main()
✅ 方案 B:必须保留最后一批(学术场景)
如果确实不能 drop_last=True(例如小数据集),可考虑对齐 batch 大小:
Padding/Repeat:在 collate_fn 里把最后一批补齐到一致大小;
EvenlyDistributedSampler:自定义 sampler,确保各 rank 拿到完全等长的 index 列表(对总长度做上采样)。
示例(最简单的“循环补齐”):
class EvenSampler(DistributedSampler):
def iter(self):
# 先拿到原始 index,再做均匀补齐
indices = list(super().iter())
# 使得 len(indices) 可整除 num_replicas
rem = len(indices) % self.num_replicas
if rem != 0:
pad = self.num_replicas - rem
indices += indices[:pad] # 简单重复前几个样本
return iter(indices)
✅ 方案 C:降低“意外丢样”风险
自定义 collate_fn 不要在空 batch 时 return None 或直接 continue,而应抛异常或做补齐。
数据增强/过滤若可能丢样,务必在 Dataset 内重采样,保证 getitem 总是返回有效样本。
若模型里有条件分支可能不参与反向(导致“未使用参数”),
要么收敛后改为固定分支;
要么在 DDP 里开启 find_unused_parameters=True(但会更慢,且仍需确保步数一致)。
验证
修复后,连续训练 50+ 个 epoch 未再出现挂起;
加上 dist.barrier() 收尾,脚本结束更干净;
打开 NCCL_BLOCKING_WAIT=1 时也不再报超时。
避坑总结(Checklist)
一定要 sampler.set_epoch(epoch),确保各 rank 洗牌一致。
不要在 DataLoader 再写 shuffle=True(使用 DistributedSampler 时交给 sampler)。
尽量 drop_last=True,避免尾批大小不一致;若必须保留尾批,就补齐到等长。
保证各 rank 步数完全一致:collate 不能静默丢 batch;Dataset 不要“偶发返回 None”。
按需设置 DDP 参数:无动态分支时 find_unused_parameters=False 更稳更快。
开 NCCL 调试:NCCL_DEBUG=INFO、NCCL_ASYNC_ERROR_HANDLING=1、NCCL_BLOCKING_WAIT=1,排障高效。
收尾同步:退出前 dist.barrier(),避免某 rank 早退影响他人。
最简复现先做整除长度:把 len(dataset) 设为 k * world_size,观察是否立刻恢复。
以上是这次 DDP 卡死问题从现象 → 排查 → 解决的完整记录。这个坑非常高频,尤其在课程项目/科研代码里常被忽视。希望这篇复盘能让你在分布式训练时少掉一把汗