当前位置: 首页 > backend >正文

PyTorch DDP 随机卡死复盘

PyTorch DDP 随机卡死复盘:最后一个 batch 挂起,NCCL 等待不返回,三步修复 Sampler & drop_last

一次真实的分布式训练“玄学卡死”:2 卡训练偶发在 epoch 尾部停住不动,GPU 利用率掉到 0%,日志无异常。最终定位是 DistributedSampler 使用不当 + drop_last=False + 忘记 set_epoch 引发各 rank 步数不一致,导致 allreduce 永久等待。

技术环境

OS:Ubuntu 22.04

Python:3.10.13

PyTorch:2.2.2 + CUDA 12.1(torch==2.2.2+cu121)

NCCL:2.18(系统自带,未自编译)

GPU:2×RTX 3090(24GB)

启动方式:torchrun --standalone --nproc_per_node=2 train.py

Bug 现象

训练随机在某些 epoch 尾部卡住,无异常栈;nvidia-smi 显示两卡功耗接近空闲。

偶尔能看到 NCCL 打印(并不总出现):

NCCL WARN Reduce failed: … Async operation timed out

kill -SIGQUIT 打印 Python 栈后发现停在 反向传播的梯度 allreduce 上(DistributedDataParallel 内部)。

关掉 DDP(单卡训练)完全正常;把 batch_size 改小/大,卡住概率改变但仍会发生。

最小可复现(错误版)

问题点集中在 数据划分不均 + Sampler 误用:

shuffle=True 与 DistributedSampler 混用(会被忽略但容易误导)。

drop_last=False 时,最后一个小批的样本数在不同 rank 上可能不一致(当 len(dataset) 不是 world_size 的整数倍且某些数据被过滤/增强丢弃时尤其明显)。

每个 epoch 忘记调用 sampler.set_epoch(epoch),导致各 rank 的随机顺序不一致。

train_ddp_wrong.py —— 错误示例(请勿照抄)

import os, random, torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler

class DummyDS(Dataset):
def init(self, N=1003): # 刻意设成非 world_size 整数倍
self.N = N
def len(self): return self.N
def getitem(self, i):
x = torch.randn(32, 3, 224, 224)
y = torch.randint(0, 10, (32,)) # 模拟有时会丢弃某些样本的增强(省略)
return x, y

def setup():
dist.init_process_group(“nccl”)
torch.cuda.set_device(int(os.environ[“LOCAL_RANK”]))

def main():
setup()
rank = dist.get_rank()
device = torch.device(“cuda”, int(os.environ[“LOCAL_RANK”]))
ds = DummyDS()

sampler = DistributedSampler(ds, shuffle=True, drop_last=False)  # ❌ drop_last=False
# ❌ DataLoader 里又写了 shuffle=True(被忽略,但容易误以为生效)
loader = DataLoader(ds, batch_size=2, shuffle=True, sampler=sampler, num_workers=4)model = torch.nn.Linear(3*224*224, 10).to(device)
model = DDP(model, device_ids=[device.index])
opt = torch.optim.SGD(model.parameters(), lr=0.1)for epoch in range(5):# ❌ 忘记 sampler.set_epoch(epoch)for x, y in loader:x = x.view(x.size(0), -1).to(device)y = y.to(device)opt.zero_grad()loss = torch.nn.functional.cross_entropy(model(x), y)loss.backward()      # 🔥 偶发卡在这里(allreduce)opt.step()if rank == 0:print(f"epoch {epoch} done")dist.destroy_process_group()

if name == “main”:
main()

触发条件(满足一两个就可能复现):

len(dataset) 不是 world_size 的整数倍。

动态数据过滤/增强(例如有时返回 None 或丢样),导致各 rank 实际步数不同。

忘记 sampler.set_epoch(epoch),各 rank 洗牌次序不同。

drop_last=False,导致最后一个 batch 在各 rank 的样本数不同。

某些自定义 collate_fn 在“空 batch”时直接 continue。

排查步骤
1)先确认“各 rank 步数一致”

在训练 loop 里加统计(不要只在 rank0 打印):

from collections import Counter
steps = Counter()
for i, _ in enumerate(loader):
steps[rank] += 1
dist.all_reduce(torch.tensor([steps[rank]], device=device), op=dist.ReduceOp.SUM)

或每个 rank 各自 print,检查是否相等

我的现象:有的 epoch,rank0 比 rank1 多 1–2 个 step。

2)开启 NCCL 调试

在启动前设置:

export NCCL_DEBUG=INFO
export NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_BLOCKING_WAIT=1

再跑一遍,可看到某些 allreduce 一直等不到某 rank 进来。

3)检查 Sampler 与 DataLoader 参数

DistributedSampler 必须搭配 sampler.set_epoch(epoch)。

DataLoader 里不要再写 shuffle=True。

若数据不可整除,优先 drop_last=True;否则确保各 rank 最后一个 batch 大小一致(例如补齐/填充)。

解决方案(修复版)
✅ 方案 A:严格对齐 Sampler 语义 + 丢最后不齐整的 batch

train_ddp_fixed.py —— 推荐修复

import os, torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, Dataset

class DummyDS(Dataset):
def init(self, N=1003): self.N=N
def len(self): return self.N
def getitem(self, i):
x = torch.randn(32, 3, 224, 224)
y = torch.randint(0, 10, (32,))
return x, y

def setup():
dist.init_process_group(“nccl”)
torch.cuda.set_device(int(os.environ[“LOCAL_RANK”]))

def main():
setup()
rank = dist.get_rank()
device = torch.device(“cuda”, int(os.environ[“LOCAL_RANK”]))

ds = DummyDS()
# 关键 1:使用 DistributedSampler,统一交给它洗牌
sampler = DistributedSampler(ds, shuffle=True, drop_last=True)  # ✅
# 关键 2:DataLoader 里不要再写 shuffle
loader = DataLoader(ds, batch_size=2, sampler=sampler, num_workers=4, pin_memory=True)model = torch.nn.Linear(3*224*224, 10).to(device)
ddp = DDP(model, device_ids=[device.index], find_unused_parameters=False)  # 如无动态分支,关掉更稳更快
opt = torch.optim.SGD(ddp.parameters(), lr=0.1)for epoch in range(5):sampler.set_epoch(epoch)  # ✅ 关键 3:每个 epoch 设置不同随机种子for x, y in loader:x = x.view(x.size(0), -1).to(device, non_blocking=True)y = y.to(device, non_blocking=True)opt.zero_grad(set_to_none=True)loss = torch.nn.functional.cross_entropy(ddp(x), y)loss.backward()opt.step()if rank == 0:print(f"epoch {epoch} ok")dist.barrier()  # ✅ 收尾同步,避免 rank 提前退出
dist.destroy_process_group()

if name == “main”:
main()

✅ 方案 B:必须保留最后一批(学术场景)

如果确实不能 drop_last=True(例如小数据集),可考虑对齐 batch 大小:

Padding/Repeat:在 collate_fn 里把最后一批补齐到一致大小;

EvenlyDistributedSampler:自定义 sampler,确保各 rank 拿到完全等长的 index 列表(对总长度做上采样)。

示例(最简单的“循环补齐”):

class EvenSampler(DistributedSampler):
def iter(self):
# 先拿到原始 index,再做均匀补齐
indices = list(super().iter())
# 使得 len(indices) 可整除 num_replicas
rem = len(indices) % self.num_replicas
if rem != 0:
pad = self.num_replicas - rem
indices += indices[:pad] # 简单重复前几个样本
return iter(indices)

✅ 方案 C:降低“意外丢样”风险

自定义 collate_fn 不要在空 batch 时 return None 或直接 continue,而应抛异常或做补齐。

数据增强/过滤若可能丢样,务必在 Dataset 内重采样,保证 getitem 总是返回有效样本。

若模型里有条件分支可能不参与反向(导致“未使用参数”),

要么收敛后改为固定分支;

要么在 DDP 里开启 find_unused_parameters=True(但会更慢,且仍需确保步数一致)。

验证

修复后,连续训练 50+ 个 epoch 未再出现挂起;

加上 dist.barrier() 收尾,脚本结束更干净;

打开 NCCL_BLOCKING_WAIT=1 时也不再报超时。

避坑总结(Checklist)

一定要 sampler.set_epoch(epoch),确保各 rank 洗牌一致。

不要在 DataLoader 再写 shuffle=True(使用 DistributedSampler 时交给 sampler)。

尽量 drop_last=True,避免尾批大小不一致;若必须保留尾批,就补齐到等长。

保证各 rank 步数完全一致:collate 不能静默丢 batch;Dataset 不要“偶发返回 None”。

按需设置 DDP 参数:无动态分支时 find_unused_parameters=False 更稳更快。

开 NCCL 调试:NCCL_DEBUG=INFO、NCCL_ASYNC_ERROR_HANDLING=1、NCCL_BLOCKING_WAIT=1,排障高效。

收尾同步:退出前 dist.barrier(),避免某 rank 早退影响他人。

最简复现先做整除长度:把 len(dataset) 设为 k * world_size,观察是否立刻恢复。

以上是这次 DDP 卡死问题从现象 → 排查 → 解决的完整记录。这个坑非常高频,尤其在课程项目/科研代码里常被忽视。希望这篇复盘能让你在分布式训练时少掉一把汗

http://www.xdnf.cn/news/20018.html

相关文章:

  • < 自用文 OS 有关 > (续)发现正在被攻击 后的自救 Fail2ban + IPset + UFW 工作流程详解
  • 十四、STM32-----低功耗
  • 【前端教程】JavaScript DOM 操作案例解析与代码优化
  • 不用服务器也能监控网络:MyIP+cpolar让中小企业告别昂贵方案
  • 【全网最全】《2025国赛/高教杯》C题 思路+代码python和matlab+文献 一到四问 退火算法+遗传算法 NIPT的时点选择与胎儿的异常判定
  • Qt 系统相关 - 1
  • 大整数乘法实现日志:从查表法到逐位运算
  • 基于深度掩码的动态模糊处理
  • 《Html泛型魔法学院:用霍格沃茨风格网页教授集合框架》
  • SpringBoot 集成 MyBatis-Plus 的使用指南
  • 学习PaddlePaddle--环境配置-Windows 11 + RTX 4060
  • 优质技术博客分享(第1期)
  • Beautiful.ai:AI辅助PPT工具高效搞定排版,告别熬夜做汇报烦恼
  • maven settings.xml文件的各个模块、含义以及它们之间的联系
  • 阿瓦隆 A1146 Pro 63T:性能与设计详解,探索区块链挖矿新高度
  • 【网工基础】20+常用网络协议介绍
  • 水下管道巡检机器人结构设cad+三维图+设计说明书
  • 2508C++,skia动画
  • 【iOS】对象复制与属性关键字
  • 同步安卓手机的照片到NAS的方案(完美)
  • 人工智能学习:鸢尾花数据获取
  • qwen-code 功能分析报告
  • 软件安装教程(四):在 Windows 上安装与配置 MATLAB(超详细)
  • 【2025企业建站推荐指南】深度解析十大顶尖网站建设公司:从品牌设计到技术落地的全维度解决方案
  • 01_配置版本
  • BERT家族进化史:从BERT到LLaMA,每一次飞跃都源于对“学习”的更深理解
  • 【面试题】生成式搜索能否保证top-1的准确性?
  • MySQL中CASE语法规则的详细解析及扩展示例
  • Spring Cloud Alibaba快速入门01
  • 去中心化投票系统开发教程