
Using Torch-Pruning, a Model Pruning Tool for Deep Learning

      Torch-Pruning (TP) is a structural pruning framework. Source code: https://github.com/VainF/Torch-Pruning; the latest release is v1.6.0, under the MIT License.

      TP supports structural pruning of a wide range of deep neural networks. Unlike torch.nn.utils.prune, which zeroes out parameters via masks, TP deploys an algorithm called DepGraph to group and remove coupled parameters.

      TP depends only on PyTorch and Numpy, and is compatible with both PyTorch 1.x and 2.x. To install v1.6.0 with pip in an Anaconda virtual environment, run the following command:

pip install torch-pruning==1.6.0
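
      After installation, a quick check confirms that the package imports and reports the expected version (a minimal sketch; importlib.metadata reads the version recorded by pip):

from importlib.metadata import version

import torch
import torch_pruning as tp # verify that the package imports cleanly

print(torch.__version__) # installed PyTorch version
print(version("torch-pruning")) # should print 1.6.0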

      In structural pruning, removing a single parameter can affect multiple layers. For example, pruning the output dimension of a linear layer also requires removing the corresponding input dimension of the next linear layer. These inter-layer dependencies make it very difficult to prune complex networks by hand. TP solves this by introducing a graph-based algorithm called DepGraph, which automatically identifies dependencies and collects the groups of parameters that must be pruned together.
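
      To make the grouping concrete, the low-level DependencyGraph interface can be used directly. The following is a minimal sketch adapted from the project's README; resnet18 and the channel indices are only illustrative:

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(weights=None).eval()
example_inputs = torch.randn(1, 3, 224, 224)

# 1. build the dependency graph; a dummy forward pass is used to trace the network
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)

# 2. get the group coupled with pruning output channels 2, 6 and 9 of model.conv1
group = DG.get_pruning_group(model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9])

# 3. prune all coupled layers in the group together
if DG.check_pruning_group(group): # avoid pruning a layer down to zero channels
    group.prune()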

      Here, the dataset from https://blog.csdn.net/fengbingchun/article/details/149307432 is used as an example, with DenseNet for classification. The test code is as follows:

      1. Prune the previously generated classification model. Note: save the pruned model with torch.save(model, name), not torch.save(model.state_dict(), name), because pruning changes the network structure and a bare state_dict can no longer be loaded into the original model definition.

import torch
import torch.nn as nn
import torchvision.models as models
import torch_pruning as tp

def model_pruning(model_name, classes_number, prune_amount):
    # reference: https://github.com/VainF/Torch-Pruning/blob/master/examples/torchvision_models/torchvision_global_pruning.py
    model = models.densenet121(weights=None)
    model.classifier = nn.Linear(model.classifier.in_features, classes_number)
    # print("before pruning, model:", model)
    model.load_state_dict(torch.load(model_name, weights_only=False, map_location="cpu"))
    original_size = tp.utils.count_params(model)

    model.cpu().eval()
    for p in model.parameters():
        p.requires_grad_(True)

    # keep the classifier intact: pruning its output would change the number of classes
    ignored_layers = []
    for m in model.modules():
        if isinstance(m, nn.Linear):
            ignored_layers.append(m)
    print(f"ignored_layers: {ignored_layers}")

    example_inputs = torch.randn(1, 3, 224, 224)

    # build the network pruner
    importance = tp.importance.MagnitudeImportance(p=1)
    pruner = tp.pruner.MagnitudePruner(
        model,
        example_inputs=example_inputs,
        importance=importance,
        iterative_steps=1,
        pruning_ratio=prune_amount,
        global_pruning=True,
        round_to=None,
        unwrapped_parameters=None,
        ignored_layers=ignored_layers,
        channel_groups={}
    )

    # record the prunable output channels/features of each layer
    layer_channel_cfg = {}
    for module in model.modules():
        if module not in pruner.ignored_layers:
            if isinstance(module, nn.Conv2d):
                layer_channel_cfg[module] = module.out_channels
            elif isinstance(module, nn.Linear):
                layer_channel_cfg[module] = module.out_features

    pruner.step() # perform the pruning
    # print("after pruning, model", model)
    result_size = tp.utils.count_params(model)
    print(f"model: original size: {original_size}; result_size: {result_size}")

    # testing: make sure the pruned model still runs forward
    with torch.no_grad():
        out = model(example_inputs)
        print("test out:", out)

    # save the whole model; torch.save(model.state_dict(), ...) can't be used here
    torch.save(model, "new_structured_prune_melon_classify.pt")
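
      A hypothetical invocation of the function above might look as follows; the model file name, class count, and pruning ratio are assumptions for this melon-classification example:

# hypothetical arguments: adjust the trained model file, number of classes and pruning ratio to your setup
model_pruning(model_name="melon_classify.pt", classes_number=2, prune_amount=0.5)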

      The changes to the model before and after pruning are shown in the figure below:

      Before pruning, the model file is about 27.1 MB; after pruning, it is about 14.0 MB.
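
      These numbers can be verified on disk. The sketch below reads back the pruned file saved above; os.path.getsize is standard library, and the parameter count uses the same tp.utils.count_params helper as the pruning code:

import os

import torch
import torch_pruning as tp

pruned_name = "new_structured_prune_melon_classify.pt"
pruned_model = torch.load(pruned_name, weights_only=False) # the whole model was saved, not a state_dict
print(f"pruned file size: {os.path.getsize(pruned_name) / 2**20:.1f} MB")
print(f"pruned parameter count: {tp.utils.count_params(pruned_model)}")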

      2. After pruning, the model needs to be fine-tuned, i.e., retrained:

import time

import colorama
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder

def _load_dataset(dataset_path, mean, std, batch_size):
    mean = _str2tuple(mean)
    std = _str2tuple(std)

    train_transform = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std), # RGB
    ])
    train_dataset = ImageFolder(root=dataset_path+"/train", transform=train_transform)
    print(f"train dataset length: {len(train_dataset)}; classes: {train_dataset.class_to_idx}; number of categories: {len(train_dataset.class_to_idx)}")
    train_loader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=0)

    val_transform = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std), # RGB
    ])
    val_dataset = ImageFolder(root=dataset_path+"/val", transform=val_transform)
    print(f"val dataset length: {len(val_dataset)}; classes: {val_dataset.class_to_idx}")
    assert len(train_dataset.class_to_idx) == len(val_dataset.class_to_idx), f"the number of categories in the train set must be equal to the number of categories in the validation set: {len(train_dataset.class_to_idx)} : {len(val_dataset.class_to_idx)}"
    val_loader = DataLoader(val_dataset, batch_size, shuffle=True, num_workers=0)

    return len(train_dataset), len(val_dataset), train_loader, val_loader

def fine_tuning(dataset_path, epochs, mean, std, model_name):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = torch.load(model_name, weights_only=False)
    model.to(device)

    train_dataset_num, val_dataset_num, train_loader, val_loader = _load_dataset(dataset_path, mean, std, 4)

    optimizer = optim.Adam(model.parameters(), lr=0.00001) # set the optimizer
    criterion = nn.CrossEntropyLoss() # set the loss

    highest_accuracy = 0.
    minimum_loss = 100.
    new_model_name = "fine_tuning_melon_classify.pt"

    for epoch in range(epochs):
        epoch_start = time.time()

        train_loss = 0.0
        train_acc = 0.0
        val_loss = 0.0
        val_acc = 0.0

        model.train() # set to training mode
        for _, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad() # clean existing gradients
            outputs = model(inputs) # forward pass
            loss = criterion(outputs, labels) # compute loss
            loss.backward() # backpropagate the gradients
            optimizer.step() # update the parameters

            train_loss += loss.item() * inputs.size(0) # compute the total loss
            _, predictions = torch.max(outputs.data, 1) # compute the accuracy
            correct_counts = predictions.eq(labels.data.view_as(predictions))
            acc = torch.mean(correct_counts.type(torch.FloatTensor)) # convert correct_counts to float
            train_acc += acc.item() * inputs.size(0) # compute the total accuracy
            # print(f"train batch number: {i}; train loss: {loss.item():.4f}; accuracy: {acc.item():.4f}")

        model.eval() # set to evaluation mode
        with torch.no_grad():
            for _, (inputs, labels) in enumerate(val_loader):
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs) # forward pass
                loss = criterion(outputs, labels) # compute loss
                val_loss += loss.item() * inputs.size(0) # compute the total loss
                _, predictions = torch.max(outputs.data, 1) # compute validation accuracy
                correct_counts = predictions.eq(labels.data.view_as(predictions))
                acc = torch.mean(correct_counts.type(torch.FloatTensor)) # convert correct_counts to float
                val_acc += acc.item() * inputs.size(0) # compute the total accuracy

        avg_train_loss = train_loss / train_dataset_num # average training loss
        avg_train_acc = train_acc / train_dataset_num # average training accuracy
        avg_val_loss = val_loss / val_dataset_num # average validation loss
        avg_val_acc = val_acc / val_dataset_num # average validation accuracy

        epoch_end = time.time()
        print(f"epoch:{epoch+1}/{epochs}; train loss:{avg_train_loss:.6f}, accuracy:{avg_train_acc:.6f}; validation loss:{avg_val_loss:.6f}, accuracy:{avg_val_acc:.6f}; time:{epoch_end-epoch_start:.2f}s")

        if highest_accuracy < avg_val_acc and minimum_loss > avg_val_loss:
            torch.save(model, new_model_name)
            highest_accuracy = avg_val_acc
            minimum_loss = avg_val_loss

        if avg_val_loss < 0.0001 or avg_val_acc > 0.9999:
            print(colorama.Fore.YELLOW + "stop training early")
            torch.save(model, new_model_name)
            break
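
      The helper _str2tuple used by _load_dataset is defined elsewhere in the original script; a minimal sketch of it, together with a hypothetical fine_tuning call (the dataset path, epoch count, and mean/std strings are assumptions), is shown below:

import ast

def _str2tuple(value):
    # parse a string such as "(0.52,0.49,0.44)" into a tuple of floats
    return tuple(ast.literal_eval(value))

# hypothetical invocation: fine-tune the pruned model for a few epochs
fine_tuning("../../data/melon_dataset", 10, "(0.52,0.49,0.44)", "(0.25,0.24,0.26)", "new_structured_prune_melon_classify.pt")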

      A few fine-tuning epochs are enough to meet the requirements; the execution results are shown in the figure below:

      3. Use the pruned model and the fine-tuned model for prediction. Note: load the model with torch.load(model_name, weights_only=False), not model.load_state_dict(torch.load(model_name, weights_only=False, map_location="cpu")), for the same reason: the pruned architecture no longer matches the original model definition.

from pathlib import Path

import torch
from PIL import Image
from torchvision import transforms

def _parse_labels_file(labels_file):
    classes = {}

    with open(labels_file, "r") as file:
        for line in file:
            idx_value = []
            for v in line.split(" "):
                idx_value.append(v.replace("\n", "")) # remove line breaks(\n) at the end of the line
            assert len(idx_value) == 2, f"the length must be 2: {len(idx_value)}"

            classes[int(idx_value[0])] = idx_value[1]

    return classes

def _get_images_list(images_path):
    image_names = []
    p = Path(images_path)

    for subpath in p.rglob("*"):
        if subpath.is_file():
            image_names.append(subpath)

    return image_names

def predict(model_name, labels_file, images_path, mean, std):
    classes = _parse_labels_file(labels_file)
    assert len(classes) != 0, "the number of categories can't be 0"

    image_names = _get_images_list(images_path)
    assert len(image_names) != 0, "no images found"

    mean = _str2tuple(mean)
    std = _str2tuple(std)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = torch.load(model_name, weights_only=False)
    model.to(device)
    model.eval()

    with torch.no_grad():
        for image_name in image_names:
            input_image = Image.open(image_name)
            preprocess = transforms.Compose([
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std) # RGB
            ])

            input_tensor = preprocess(input_image) # (c,h,w)
            input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model, (1,c,h,w)
            input_batch = input_batch.to(device)

            output = model(input_batch)
            probabilities = torch.nn.functional.softmax(output[0], dim=0) # the output has unnormalized scores, to get probabilities, you can run a softmax on it
            max_value, max_index = torch.max(probabilities, dim=0)
            print(f"{image_name.name}\t{classes[max_index.item()]}\t{max_value.item():.4f}")
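
      A hypothetical call to predict is sketched below; the labels file is assumed to contain one "index name" pair per line, and the paths, mean, and std are assumptions:

# hypothetical invocation: classify every image under the test directory
predict("fine_tuning_melon_classify.pt", "labels.txt", "../../data/melon_dataset/test", "(0.52,0.49,0.44)", "(0.25,0.24,0.26)")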

      The execution results are shown in the figure below: the model before fine-tuning has very low accuracy, while the fine-tuned model has very high accuracy.

      GitHub: https://github.com/fengbingchun/NN_Test
