当前位置: 首页 > web >正文

DataParallel (DP) DistributedDataParallel (DDP)

DataParallel (DP) 和DistributedDataParallel (DDP)的区别

在 PyTorch 中,DataParallel (DP) 是将总 batch size 均分至每个 GPU 的实现方式,而 DistributedDataParallel (DDP) 则是每个 GPU 独立设置 batch size。具体区别如下:

1. DataParallel (DP):自动均分总 batch size

  • 特点

    • 单进程多线程,所有 GPU 由主进程(rank 0)控制。
    • 输入的 batch_size 表示所有 GPU 上的总样本数,会被自动均分至每个 GPU。
    • 例如:batch_size=128 且使用 4 个 GPU 时,每个 GPU 处理 128/4 = 32 个样本。
  • 代码示例

    model = nn.DataParallel(model)  # 包裹模型
    train_loader = DataLoader(dataset, batch_size=128)  # 总 batch size
    
  • 缺点

    • 主进程成为瓶颈,通信开销大,效率低于 DDP。
    • 内存使用不均衡(主 GPU 占用更多)。
    • 不支持模型并行(仅数据并行)。

2. DistributedDataParallel (DDP):每个 GPU 独立设置 batch size

  • 特点

    • 多进程模式,每个 GPU 对应一个独立进程。
    • 输入的 batch_size 表示每个 GPU 上的样本数,全局 batch size = 本地 batch size × GPU 数量。
    • 例如:batch_size=32 且使用 4 个 GPU 时,总 batch size 为 32×4 = 128
  • 代码示例

    # 初始化 DDP
    torch.distributed.init_process_group(backend='nccl')
    local_rank = torch.distributed.get_rank()
    model = model.to(local_rank)
    model = nn.DistributedDataParallel(model, device_ids=[local_rank])# 每个 GPU 的 DataLoader 加载本地 batch size
    train_loader = DataLoader(dataset, batch_size=32)  # 每个 GPU 32 样本
    
  • 优点

    • 高效通信(使用 NCCL 后端),支持大规模分布式训练。
    • 内存使用更均衡,训练速度更快

DP 与单卡训练的核心区别

DataParallel 的核心作用是利用多卡进行数据并行训练,与单卡训练的关键差异体现在以下方面:

1. 计算方式

单卡训练:
模型和数据都放在单个 GPU 上,所有计算(前向传播、反向传播)都在这张卡上完成。
例如:batch_size=128 时,128 个样本全部在 GPU 0 上处理。

DP 训练:
主进程(通常是 GPU 0)将模型复制到所有参与训练的 GPU 上(如 GPU 0、1、2、3)。
输入的 batch_size=128 会被均分到每个 GPU(例如 4 卡时,每卡处理 32 个样本)。
每个 GPU 独立进行前向传播,计算各自的损失。
所有 GPU 的梯度会被汇总到主 GPU(GPU 0),主 GPU 执行参数更新后,再将更新后的权重同步到其他 GPU。

2. 性能与适用场景

单卡训练:
优点:实现简单,无多卡通信开销。
缺点:受限于单卡显存,无法处理过大的模型或 batch size。

DP 训练:
优点:只需一行代码(model = nn.DataParallel(model))即可实现多卡数据并行,适合中小规模多卡场景(如 2-4 卡)。

缺点:主 GPU(GPU 0)承担更多通信和计算任务(如汇总梯度、更新参数),容易成为瓶颈,内存占用也更高。
效率低于 DDP(多进程模式),不适合大规模分布式训练(如 8 卡以上)。

http://www.xdnf.cn/news/16665.html

相关文章:

  • JavaWeb学习打卡18(JDBC案例详解)
  • [leetcode] 电话号码的排列组合
  • CSRF漏洞原理
  • CentOS7 安装和配置教程
  • USRP X410 X440 5G及未来通信技术的非地面网络(NTN)
  • Matplotlib(三)- 图表辅助元素
  • 经典算法题解析:从思路到实现,掌握核心编程思维
  • 天学网面试总结 —— 前端开发岗
  • Go 语言-->指针
  • 【2025/07/28】GitHub 今日热门项目
  • 【服务器知识】nginx配置ipv6支持
  • 大模型的开发应用(十九):AIGC基础
  • 【Spring WebFlux】 三、响应式流规范与实战
  • Java 笔记 serialVersionUID
  • ADB+Python控制(有线/无线) Scrcpy+按键映射(推荐)
  • 服务器查日志太慢,试试grep组合拳
  • 时序数据库选型指南:工业大数据场景下基于Apache IoTDB技术价值与实践路径
  • 5 分钟上手 Firecrawl
  • 【办公类-109-01】20250728托小班新生挂牌(学号姓名)
  • API产品升级丨全知科技发布「知影-API风险监测平台」:以AI重构企业数据接口安全治理新范式
  • 企业级日志分析系统ELK
  • Pycaita二次开发基础代码解析:点距测量、对象层级关系与选择机制深度剖析
  • 基于DeepSeek大模型和STM32的矿井“围压-温度-开采扰动“三位一体智能监测系统设计
  • 边缘计算+前端实时性:本地化数据处理在设备监控中的响应优化实践
  • vue element 封装表单
  • STM32时钟源
  • GaussDB as的用法
  • 【氮化镓】GaN同质外延p-i-n二极管中星形与三角形扩展表面缺陷的电子特性
  • 力扣 hot100 Day58
  • LeetCode 2044.统计按位或能得到最大值的子集数目:二进制枚举/DFS回溯(剪枝)