以MNIST数据集为例进行单机多卡训练(DP和DDP)
在单机多卡环境下使用PyTorch训练MNIST数据集时,可以通过DataParallel (DP) 和 DistributedDataParallel (DDP) 两种方式实现多卡并行。以下是具体实现示例和对比:
1. DataParallel (DP) 方式
DP是单进程多线程的简单并行方式,将模型复制到多个GPU,数据切分后分发到不同GPU计算,最后在主GPU聚合梯度。
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 定义模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc = nn.Linear(784, 10)def forward(self, x):return self.fc(x.view(x.size(0), -1))# 数据加载
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)# 初始化模型和优化器
model = Net()
model = nn.DataParallel(model) # 包装为DP模式
model = model.cuda()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练循环
for epoch in range(5):for data, target in train_loader:data, target = data.cuda(), target.cuda()optimizer.zero_grad()output = model(data)loss = nn.CrossEntropyLoss()(output, target)loss.backward()optimizer.step()print(f'Epoch {epoch}, Loss: {loss.item()}')
DP的缺点:
- 单进程控制多卡,存在GIL锁限制。
- 主GPU显存瓶颈(需聚合梯度)。
- 效率低于DDP。
2. DistributedDataParallel (DDP) 方式
DDP是多进程并行,每个GPU独立运行一个进程,通过NCCL通信同步梯度,效率更高且无主GPU瓶颈。
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, DistributedSamplerdef setup(rank, world_size):dist.init_process_group("nccl", rank=rank, world_size=world_size)def cleanup():dist.destroy_process_group()class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc = nn.Linear(784, 10)def forward(self, x):return self.fc(x.view(x.size(0), -1))def train(rank, world_size):setup(rank, world_size)# 每个进程独立加载数据(使用DistributedSampler)transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)train_loader = DataLoader(train_dataset, batch_size=64, sampler=sampler)# 初始化模型和优化器model = Net().to(rank)model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练循环for epoch in range(5):sampler.set_epoch(epoch) # 确保每个epoch的shuffle不同for data, target in train_loader:data, target = data.to(rank), target.to(rank)optimizer.zero_grad()output = model(data)loss = nn.CrossEntropyLoss()(output, target)loss.backward()optimizer.step()if rank == 0: # 仅主进程打印print(f'Epoch {epoch}, Loss: {loss.item()}')cleanup()if __name__ == '__main__':world_size = torch.cuda.device_count()mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
DDP的关键点:
- 多进程启动:
mp.spawn
启动多个进程,每个进程绑定一个GPU。 - 进程组初始化:
init_process_group
设置NCCL后端。 - 数据分片:
DistributedSampler
确保每个进程读取不同数据。 - 模型包装:
DistributedDataParallel
自动同步梯度。
DP vs DDP 对比
特性 | DataParallel (DP) | DistributedDataParallel (DDP) |
---|---|---|
并行模式 | 单进程多线程 | 多进程 |
通信效率 | 低(主GPU聚合瓶颈) | 高(NCCL直接通信) |
显存占用 | 主GPU显存压力大 | 各GPU显存均衡 |
代码复杂度 | 简单(无需修改数据加载) | 较复杂(需配置进程组和Sampler) |
适用场景 | 快速原型开发 | 生产环境大规模训练 |
总结
- DP适合快速验证多卡可行性,但效率低。
- DDP是PyTorch官方推荐的多卡训练方式,适合实际生产环境。