模型剪枝的定义与核心作用
模型剪枝(Model Pruning)是一种通过移除神经网络中冗余参数或结构(如权重、神经元、注意力头等)来压缩模型的技术。其核心目标是在保持模型性能的前提下,降低计算复杂度、存储需求和推理延迟。
核心作用:
- 降低计算成本:减少浮点运算量(FLOPs),提升推理速度(如剪枝50%参数,性能仅下降1-3%)。
- 减少内存占用:压缩模型体积,便于在边缘设备(如手机、IoT设备)部署。
- 硬件友好性:结构化剪枝移除整个卷积核/通道,支持GPU/TPU加速。
- 缓解过拟合:去除冗余参数可能增强泛化能力。
1、PyTorch完整实现:MNIST分类模型剪枝
- 环境准备与数据加载
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.utils.prune as prune
- 数据预处理
# 数据预处理
# MNIST preprocessing: tensor conversion plus normalization with the
# dataset's standard per-channel mean/std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Train split is downloaded on first use; test split reuses the same cache.
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

# Shuffle only the training stream; evaluate in large fixed batches.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000)
- 定义全连接网络
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(784, 256)self.fc2 = nn.Linear(256, 128)self.fc3 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 784)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return F.log_softmax(x, dim=1)
- 训练原始模型
def train(model, optimizer, epoch):
    """Run one training epoch over the global ``train_loader``.

    Uses negative log-likelihood loss (model outputs log-probabilities)
    and logs the running loss every 200 batches.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 200 == 0:
            print(f'Epoch {epoch} | Loss: {loss.item():.4f}')


# Instantiate the baseline model and its optimizer.
model = SimpleNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)
- 训练3个epoch
# 训练3个epoch
for epoch in range(1, 4):train(model, optimizer, epoch)
- 应用L1非结构化剪枝
#应用L1非结构化剪枝
def apply_pruning(model, amount=0.3):# 全局剪枝:统一分配各层剪枝比例parameters_to_prune = ((model.fc1, 'weight'),(model.fc2, 'weight'),(model.fc3, 'weight'),)prune.global_unstructured(parameters_to_prune,pruning_method=prune.L1Unstructured,amount=amount)
- 剪除30%权重
# Globally prune 30% of the FC weights (default amount of apply_pruning).
apply_pruning(model)

# Report the resulting sparsity of the first layer.
zero_count = torch.sum(model.fc1.weight == 0)
element_count = model.fc1.weight.nelement()
print(f"剪枝后稀疏度:fc1={100*zero_count/element_count:.1f}%")
- 微调剪枝后模型,微调2个epoch
# Fine-tune the pruned model for 2 epochs to recover accuracy.
for epoch in range(1, 3):
    train(model, optimizer, epoch)
- 性能评估
def test(model):model.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_loader:output = model(data)test_loss += F.nll_loss(output, target, reduction='sum').item()pred = output.argmax(dim=1)correct += pred.eq(target.view_as(pred)).sum().item()accuracy = 100. * correct / len(test_loader.dataset)print(f'Test Accuracy: {accuracy:.2f}%')test(model)
- 统计稀疏性函数
# 统计稀疏性
def print_sparsity(model):for name, param in model.named_parameters():if 'weight' in name:sparsity = 100 * (param == 0).sum().item() / param.numel()print(f"{name} 稀疏度: {sparsity:.2f}%")
- 结构化剪枝实现,剪除卷积层50% weight; 使用prune.remove()永久移除被剪枝参数,生成稀疏模型。
#结构化剪枝实现(卷积层示例)
# 示例:对卷积层进行L2范数剪枝(移除50%输出通道)
# Structured pruning demo: drop 50% of a conv layer's output channels,
# ranked by L2 norm, then bake the result into the weight tensor.
conv = torch.nn.Conv2d(3, 64, kernel_size=3)

print_sparsity(conv)  # before: freshly initialized, ~0% zeros

# dim=0 selects whole output channels; n=2 ranks them by L2 norm.
prune.ln_structured(conv, name="weight", amount=0.5, n=2, dim=0)

# Fold the mask into the weight and delete weight_orig/weight_mask,
# leaving an ordinary (sparse-valued) parameter.
prune.remove(conv, "weight")

print_sparsity(conv)  # after: ~50% of entries are exactly zero
2、手动实现L1权重剪枝(剪除30%权重)
def manual_pruning(model, amount=0.3):for name, param in model.named_parameters():# print(name)if 'weight' in name and 'fc' in name: # 仅处理全连接层的权重weights = param.data.abs() # 取绝对值作为重要性度量threshold = torch.quantile(weights.flatten(), amount) # 计算剪枝阈值mask = (weights > threshold).float() # 生成掩码(保留高于阈值的权重)param.data *= mask # 应用掩码(剪枝)# 执行剪枝
# Apply the manual pruning pass (30% of FC weights) to the trained model.
manual_pruning(model, amount=0.3)


# NOTE(review): this re-definition shadows the identical print_sparsity
# defined earlier in the tutorial; it is kept so this section runs
# standalone, but the duplicate could be removed.
def print_sparsity(model):
    """Print the percentage of exactly-zero entries in every weight tensor."""
    for name, param in model.named_parameters():
        if 'weight' in name:
            sparsity = 100 * (param == 0).sum().item() / param.numel()
            print(f"{name} 稀疏度: {sparsity:.2f}%")


print_sparsity(model)