当前位置: 首页 > news >正文

模型剪枝的定义与核心作用

模型剪枝(Model Pruning)是一种通过移除神经网络中冗余参数或结构(如权重、神经元、注意力头等)来压缩模型的技术。其核心目标是在保持模型性能的前提下,降低计算复杂度、存储需求和推理延迟。

核心作用:

  • 降低计算成本:减少浮点运算量(FLOPs),提升推理速度(如剪枝50%参数,性能仅下降1-3%)。

  • 减少内存占用:压缩模型体积,便于在边缘设备(如手机、IoT设备)部署。

  • 硬件友好性:结构化剪枝移除整个卷积核/通道,支持GPU/TPU加速。

  • 缓解过拟合:去除冗余参数可能增强泛化能力。

1、PyTorch完整实现:MNIST分类模型剪枝

  • 环境准备与数据加载
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.utils.prune as prune
  • 数据预处理
# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000)
  • 定义全连接网络
class SimpleNet(nn.Module):def __init__(self):super(SimpleNet, self).__init__()self.fc1 = nn.Linear(784, 256)self.fc2 = nn.Linear(256, 128)self.fc3 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 784)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return F.log_softmax(x, dim=1)
  • 训练原始模型
def train(model, optimizer, epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):optimizer.zero_grad()output = model(data)loss = F.nll_loss(output, target)loss.backward()optimizer.step()if batch_idx % 200 == 0:print(f'Epoch {epoch} | Loss: {loss.item():.4f}')model = SimpleNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)
  • 训练3个epoch

# 训练3个epoch
for epoch in range(1, 4):train(model, optimizer, epoch)
  • 应用L1非结构化剪枝
#应用L1非结构化剪枝
def apply_pruning(model, amount=0.3):# 全局剪枝:统一分配各层剪枝比例parameters_to_prune = ((model.fc1, 'weight'),(model.fc2, 'weight'),(model.fc3, 'weight'),)prune.global_unstructured(parameters_to_prune,pruning_method=prune.L1Unstructured,amount=amount)
  • 剪除30%权重
# 剪除30%权重
apply_pruning(model)
print(f"剪枝后稀疏度:fc1={100*torch.sum(model.fc1.weight==0)/model.fc1.weight.nelement():.1f}%")
  • 微调剪枝后模型,微调2个epoch
# 微调2个epoch
for epoch in range(1, 3):train(model, optimizer, epoch)
  • 性能评估
def test(model):model.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_loader:output = model(data)test_loss += F.nll_loss(output, target, reduction='sum').item()pred = output.argmax(dim=1)correct += pred.eq(target.view_as(pred)).sum().item()accuracy = 100. * correct / len(test_loader.dataset)print(f'Test Accuracy: {accuracy:.2f}%')test(model)
  • 统计 稀疏性 函数
# 统计稀疏性
def print_sparsity(model):for name, param in model.named_parameters():if 'weight' in name:sparsity = 100 * (param == 0).sum().item() / param.numel()print(f"{name} 稀疏度: {sparsity:.2f}%")
  • 结构化剪枝实现,剪除卷积层50% weight; 使用prune.remove()永久移除被剪枝参数,生成稀疏模型。
#结构化剪枝实现(卷积层示例)
# 示例:对卷积层进行L2范数剪枝(移除50%输出通道)
conv = torch.nn.Conv2d(3, 64, kernel_size=3)
print_sparsity(conv)
prune.ln_structured(conv, name="weight", amount=0.5, n=2, dim=0)  # L2范数剪枝
prune.remove(conv, "weight")  # 永久移除剪枝掩码
print_sparsity(conv)

2、手动实现L1权重剪枝(剪除30%权重)

def manual_pruning(model, amount=0.3):for name, param in model.named_parameters():# print(name)if 'weight' in name and 'fc' in name:  # 仅处理全连接层的权重weights = param.data.abs()         # 取绝对值作为重要性度量threshold = torch.quantile(weights.flatten(), amount)  # 计算剪枝阈值mask = (weights > threshold).float()  # 生成掩码(保留高于阈值的权重)param.data *= mask                     # 应用掩码(剪枝)# 执行剪枝
manual_pruning(model, amount=0.3)# 统计稀疏性
def print_sparsity(model):for name, param in model.named_parameters():if 'weight' in name:sparsity = 100 * (param == 0).sum().item() / param.numel()print(f"{name} 稀疏度: {sparsity:.2f}%")print_sparsity(model)
http://www.xdnf.cn/news/583381.html

相关文章:

  • 硬件开发复盘实战指南
  • CTF签到题
  • 自制操作系统day8 (鼠标数据取得、通往32位模式之路、A20GATE、切换到保护模式、控制寄存器cr0-cr4以及cr8、ALIGNB)
  • 基于 AMDXCVU47P HBM2 FPGA 的 2 路 100G 光纤 PCIe 高性能计算加速卡
  • LabVIEW多通道液位监控
  • 框架开发与原生开发的权衡:React案例分析(原生JavaScript)
  • 【hadoop】Spark的安装部署
  • jvm安全点(五)openjdk17 c++源码垃圾回收之安全点阻塞状态线程在安全点同步中无需调用block函数的详细流程解析
  • Vue:axios(GET请求)
  • 【VLNs篇】04:SayNav-为新环境中的动态规划到导航进行大型语言模型的基础构建
  • 批量处理合并拆分pdf功能 OCR 准确率高 免费开源
  • 华为昇腾开发——多模型资源管理(C++)
  • Apollo10.0学习——planning模块(9)之参数详解二
  • WooCommerce缓存教程 – 如何防止缓存破坏你的WooCommerce网站?
  • 7.2.顺序查找
  • 黑马点评前端Nginx启动失败问题解决记录
  • day26- 系统编程之 文件IO(II) 及 文件属性
  • 数据结构:绪论之时间复杂度与空间复杂度
  • 论文阅读笔记——PixArt-α,PixArt-δ
  • 滚珠导轨:重构精密仪器传动架构,开启微纳世界
  • C++-继承
  • k8s容器入门(1)有状态服务 vs 无状态服务 核心区别
  • list(c++)
  • 排序和排列——蓝桥杯备考
  • 在Java的list.forEach(即 Stream API 的 forEach 方法)中,无法直接使用 continue 或 break 语句的解决办法
  • Lucide:一款精美的开源矢量图标库,前端图标新选择
  • 5G 核心网中的 NPN 功能详解
  • MongoDB大数据量的优化——mongoTemplate.stream()方法使用
  • 参与开发的注意事项
  • 每日算法-250522