当前位置: 首页 > ops >正文

【免费可用】【提供源代码】对YOLOV11模型进行剪枝和蒸馏

yolov11_prune_distillation

该项目可以用于YOLOv11网络的训练,静态剪枝和知识蒸馏。可以在减少模型参数量的同时,尽量保证模型的推理精度。

Github链接:https://github.com/zhahoi/yolov11_prune_distillation.git

🤗Current Ultralytics version: 8.3.160

🔧 Install Dependencies

pip install torch-pruning 
pip install -r requirements.txt

🚂 Training & Pruning & Knowledge Distillation

📊 YOLO11 Training Example

### train.py
from ultralytics import YOLOif __name__ == "__main__":model = YOLO('yolo11.yaml')results = model.train(data='uno.yaml', epochs=100, imgsz=640, batch=8, device="0", name='yolo11', workers=0, prune=False)

✂️ YOLO11 Pruning Example

### prune.py
from ultralytics import YOLO# model = YOLO('yolo11.yaml')
model = YOLO('runs/detect/yolo11/weights/best.pt')def prunetrain(train_epochs, prune_epochs=0, quick_pruning=True, prune_ratio=0.5, prune_iterative_steps=1, data='coco.yaml', name='yolo11', imgsz=640, batch=8, device=[0], sparse_training=False):if not quick_pruning:assert train_epochs > 0 and prune_epochs > 0, "Quick Pruning is not set. prune epochs must > 0."print("Phase 1: Normal training...")model.train(data=data, epochs=train_epochs, imgsz=imgsz, batch=batch, device=device, name=f"{name}_phase1", prune=False,sparse_training=sparse_training)print("Phase 2: Pruning training...")best_weights = f"runs/detect/{name}_phase1/weights/best.pt"pruned_model = YOLO(best_weights)return pruned_model.train(data=data, epochs=prune_epochs, imgsz=imgsz, batch=batch, device=device, name=f"{name}_pruned", prune=True,prune_ratio=prune_ratio, prune_iterative_steps=prune_iterative_steps)else:return model.train(data=data, epochs=train_epochs, imgsz=imgsz, batch=batch, device=device, name=name, prune=True, prune_ratio=prune_ratio, prune_iterative_steps=prune_iterative_steps)if __name__ == '__main__':# Normal Pruningprunetrain(quick_pruning=False,       # Quick Pruning or notdata='uno.yaml',          # Dataset configtrain_epochs=10,           # Epochs before pruningprune_epochs=20,           # Epochs after pruning imgsz=640,                 # Input sizebatch=8,                   # Batch sizedevice=[0],                # GPU devicesname='yolo11_prune',             # Save nameprune_ratio=0.5,           # Pruning Ratio (50%)prune_iterative_steps=1,   # Pruning Interative Stepssparse_training=True      # Experimental, Allow Sparse Training Before Pruning)# Quick Pruning (prune_epochs no need)# prunetrain(quick_pruning=True, data='coco.yaml', train_epochs=10, imgsz=640, batch=8, device=[0], name='yolo11', #            prune_ratio=0.5, prune_iterative_steps=1)

🔎 YOLO11 Knowledge Distillation Example

### knowledge_distillation.py
from ultralytics import YOLO
from ultralytics.nn.attention.attention import ParallelPolarizedSelfAttention
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.utils.torch_utils import model_infodef add_attention(model):at0 = model.model.model[4]n0 = at0.cv2.conv.out_channelsat0.attention = ParallelPolarizedSelfAttention(n0)at1 = model.model.model[6]n1 = at1.cv2.conv.out_channelsat1.attention = ParallelPolarizedSelfAttention(n1)at2 = model.model.model[8]n2 = at2.cv2.conv.out_channelsat2.attention = ParallelPolarizedSelfAttention(n2)return modelif __name__ == "__main__":# layers = ["6", "8", "13", "16", "19", "22"]layers = ["4", "6", "10", "16", "19", "22"]model_t = YOLO('runs/detect/yolo11/weights/best.pt')  # the teacher modelmodel_s = YOLO("runs/detect/yolo11_prune_pruned/weights/best.pt")  # the student modelmodel_s = add_attention(model_s) # Add attention to the student model# configure overridesoverrides = {"model": "runs/detect/yolo11_prune_pruned/weights/best.pt","Distillation": model_t.model,"loss_type": "mgd","layers": layers,"epochs": 50,"imgsz": 640,"batch": 8,"device": 0,"lr0": 0.001,"amp": False,"sparse_training": False,"prune": False,"prune_load": False,"workers": 0,"data": "data.yaml","name": "yolo11_distill"}trainer = DetectionTrainer(overrides=overrides)trainer.model = model_s.model model_info(trainer.model, verbose=True)trainer.train()

📤 Model Export

Export to ONNX Format Example

### export.py
from ultralytics import YOLOmodel = YOLO('runs/detect/yolo11_distill/weights/yolo11n.pt')
print(model.model)
model.export(format='onnx')

🌞 Model Inference

Image Inference Example

### infer.py
from ultralytics import YOLO
model = YOLO('runs/detect/yolo11/weights/best.pt') # model = YOLO('prune.pt')
model.predict('fruits.jpg', save=True, device=[0], line_width=2)

🔢 Model Analysis

Use thop to easily calculate model parameters and FLOPs:

pip install thop

You can calculate model parameters and flops by using calculate.py

🤝 Contributing & Support

Feel free to submit issues or pull requests on GitHub for questions or suggestions!

📚 Acknowledgements

  • Special thanks to @VainF for the contribution to the Torch-Pruning project! This project relies on it for model pruning.
  • Special thanks to @Ultralytics for the contribution to the ultralytics project! This project relies on it for the framework.
  • YOLO-Pruning-RKNN
  • yolov11_prune_distillation_v2
http://www.xdnf.cn/news/16448.html

相关文章:

  • 跨境协作系统文化适配:多语言环境下的业务符号隐喻与交互习惯
  • Java项目:基于SSM框架实现的社区团购管理系统【ssm+B/S架构+源码+数据库+毕业论文+答辩PPT+远程部署】
  • Nuxt3 全栈作品【通用信息管理系统】修改密码
  • 亚远景-“过度保守”还是“激进创新”?ISO/PAS 8800的99.9%安全阈值之争
  • 【NLP舆情分析】基于python微博舆情分析可视化系统(flask+pandas+echarts) 视频教程 - 微博文章数据可视化分析-点赞区间实现
  • 【HTTP】防XSS+SQL注入:自定义HttpMessageConverter过滤链深度解决方案
  • 【数据标注】详解使用 Labelimg 进行数据标注的 Conda 环境搭建与操作流程
  • 572. 另一棵树的子树
  • 电子签章(PDF)
  • 【0基础PS】PS工具详解--选择工具--对象选择工具
  • 【Linux | 网络】传输层(UDP和TCP) - 两万字详细讲解!!
  • 利用软件定义无线USRP X410、X440 电推进无线原型设计
  • ksql连接数据库免输入密码交互
  • 设计模式(十四)行为型:职责链模式详解
  • 飞牛NAS本地化部署n8n打造个人AI工作流中心
  • 【Java系统接口幂等性解决实操】
  • SpringSecurity实战:核心配置技巧
  • 记录几个SystemVerilog的语法——时钟块和进程通信
  • 盛最多水的容器-leetcode
  • 洛谷 P10446 64位整数乘法-普及-
  • 详解力扣高频SQL50题之1164. 指定日期的产品价格【中等】
  • 3,Windows11安装docker保姆级教程
  • LeetCode 76:最小覆盖子串
  • mybatis的insert(pojo),会返回pojo吗
  • Petalinux生成文件的关系
  • 力扣面试150题--二进制求和
  • mmap机制
  • 2.qt调试日志输出
  • 《C++》STL--string详解(上)
  • vue3报错:this.$refs.** undefined