当前位置: 首页 > web >正文

以MNIST数据集为例进行单机多卡训练(DP和DDP)

在单机多卡环境下使用PyTorch训练MNIST数据集时,可以通过DataParallel (DP)DistributedDataParallel (DDP) 两种方式实现多卡并行。以下是具体实现示例和对比:


1. DataParallel (DP) 方式

DP是单进程多线程的简单并行方式,将模型复制到多个GPU,数据切分后分发到不同GPU计算,最后在主GPU聚合梯度。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 定义模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc = nn.Linear(784, 10)def forward(self, x):return self.fc(x.view(x.size(0), -1))# 数据加载
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)# 初始化模型和优化器
model = Net()
model = nn.DataParallel(model)  # 包装为DP模式
model = model.cuda()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练循环
for epoch in range(5):for data, target in train_loader:data, target = data.cuda(), target.cuda()optimizer.zero_grad()output = model(data)loss = nn.CrossEntropyLoss()(output, target)loss.backward()optimizer.step()print(f'Epoch {epoch}, Loss: {loss.item()}')

DP的缺点

  • 单进程控制多卡,存在GIL锁限制。
  • 主GPU显存瓶颈(需聚合梯度)。
  • 效率低于DDP。

2. DistributedDataParallel (DDP) 方式

DDP是多进程并行,每个GPU独立运行一个进程,通过NCCL通信同步梯度,效率更高且无主GPU瓶颈。

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, DistributedSamplerdef setup(rank, world_size):dist.init_process_group("nccl", rank=rank, world_size=world_size)def cleanup():dist.destroy_process_group()class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc = nn.Linear(784, 10)def forward(self, x):return self.fc(x.view(x.size(0), -1))def train(rank, world_size):setup(rank, world_size)# 每个进程独立加载数据(使用DistributedSampler)transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)train_loader = DataLoader(train_dataset, batch_size=64, sampler=sampler)# 初始化模型和优化器model = Net().to(rank)model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练循环for epoch in range(5):sampler.set_epoch(epoch)  # 确保每个epoch的shuffle不同for data, target in train_loader:data, target = data.to(rank), target.to(rank)optimizer.zero_grad()output = model(data)loss = nn.CrossEntropyLoss()(output, target)loss.backward()optimizer.step()if rank == 0:  # 仅主进程打印print(f'Epoch {epoch}, Loss: {loss.item()}')cleanup()if __name__ == '__main__':world_size = torch.cuda.device_count()mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)

DDP的关键点

  1. 多进程启动mp.spawn 启动多个进程,每个进程绑定一个GPU。
  2. 进程组初始化init_process_group 设置NCCL后端。
  3. 数据分片DistributedSampler 确保每个进程读取不同数据。
  4. 模型包装DistributedDataParallel 自动同步梯度。

DP vs DDP 对比

特性DataParallel (DP)DistributedDataParallel (DDP)
并行模式单进程多线程多进程
通信效率低(主GPU聚合瓶颈)高(NCCL直接通信)
显存占用主GPU显存压力大各GPU显存均衡
代码复杂度简单(无需修改数据加载)较复杂(需配置进程组和Sampler)
适用场景快速原型开发生产环境大规模训练

总结

  • DP适合快速验证多卡可行性,但效率低。
  • DDP是PyTorch官方推荐的多卡训练方式,适合实际生产环境。
http://www.xdnf.cn/news/14224.html

相关文章:

  • 每日算法刷题Day31 6.14:leetcode二分答案2道题,结束二分答案,开始枚举技巧,用时1h10min
  • 【生活系列】金刚经
  • 使用 FastMCP 实现 Word 文档与 JSON 数据互转的 Python 服务
  • PHP、Apache环境中部署sqli-labs
  • 【构建】C++包管理器介绍
  • 从0开始学习语言模型--Day01--亲自构筑语言模型的重要性
  • python中的异常处理try-except - else - finally与自定义异常处理
  • R语言文本探索与预处理:入门指南
  • PH热榜 | 2025-06-14
  • C++开源协程库async_simple有栈协程源码分析
  • SQL Server 窗口函数详解:窗口行数控制的原理、关键字与应用场景
  • 计算机网络-自顶向下—第五章数据链路层重点复习笔记
  • Thread的join方法
  • python+django/flask+uniapp宠物中心信息管理系统app
  • Java开发中避免NullPointerException的全面指南
  • 【三维重建】无位姿图像的大场景On-the-fly重建
  • 【Linux】初见,进程概念
  • 创客匠人解析:美团护城河战略对 IP 可持续变现的启示
  • TCP 协议
  • 2025年EAAI SCI1区TOP,贪婪策略粒子群算法GS-IPSO+无人机桥梁巡检覆盖路径规划,深度解析+性能实测
  • 函数式编程 stream流 lambda表达式
  • event.target 详解:理解事件目标对象
  • 学习昇腾开发的第二天--PC机远程登录开发板
  • 大IPD之——华为的管理变革与战略转型之道(三)
  • 05-Linux软件安装与前后端项目部署
  • adoc(asciidoc)转为markdown的方法,把.adoc文件转换为markdown格式
  • PostgreSQL的扩展pg_visibility
  • 【办公类-25-05】20250514 Python模拟UIBOT上传园园通截图(自动最小化界面,时间部分的删除和黏贴)
  • 【CSS-13】CSS 网页布局三大机制详解:普通流、浮动与定位
  • 2.2 订阅话题