当前位置: 首页 > web >正文

测试ppyoloe的小样本few-shot能力,10张图片精度达到69.8%

近期公司有个项目,需要解决长尾样本的问题,所以测试了一下paddlepaddle小样本的能力。

环境::T4  、ubuntu 、cuda-11.6 、py3.9、   paddlepaddle-gpu==2.6.0、pip install opencv-python==4.5.5.64 -i https://pypi.tuna.tsinghua.edu.cn/simple    、 pip install  numpy==1.23.0

预训练模型:ppyoloe_crn_s_obj365_pretrained.pdparams

数据集下载地址:五种水果目标检测数据集coco格式_数据集-飞桨AI Studio星河社区


1、数据集准备五种水果:蕃茄、核桃、桔子、龙眼、青枣。共300张图像,640*480.COCO格式

2、先正常训练一波
数据如下:165步0.735的%表现

3、用脚本每个coco类别从原train.json提取10张图片,代码:
 

import json
from collections import defaultdict
import argparse
import osdef create_small_sample_coco(original_json, output_json, samples_per_class=10):"""从COCO格式的标注文件中,为每个类别提取指定数量的样本,并生成新的COCO标注文件参数:original_json (str): 原始COCO标注文件路径output_json (str): 输出的小样本COCO标注文件路径samples_per_class (int): 每个类别提取的样本数量"""# 加载原始标注数据with open(original_json, 'r', encoding='utf-8') as f:coco_data = json.load(f)# 确保必要的字段存在,不存在则添加默认值required_fields = {'info': {'description': 'Small sample dataset'},'licenses': [{'id': 0, 'name': 'Unknown'}],'categories': [],'images': [],'annotations': []}for field, default in required_fields.items():if field not in coco_data:print(f"警告: 标注文件缺少 '{field}' 字段,将使用默认值")coco_data[field] = default# 1. 统计每个类别的标注数量category_counts = defaultdict(int)for ann in coco_data['annotations']:cat_id = ann['category_id']category_counts[cat_id] += 1# 检查是否有类别if not category_counts:print("错误: 标注文件中未找到任何类别或标注")return# 2. 为每个类别选择指定数量的样本selected_images = set()  # 存储被选中的image_idcategory_samples = defaultdict(int)  # 记录每个类别已选择的样本数for ann in coco_data['annotations']:cat_id = ann['category_id']img_id = ann['image_id']# 如果该类别已选样本数不足,且该图片尚未被选中if category_samples[cat_id] < samples_per_class and img_id not in selected_images:selected_images.add(img_id)category_samples[cat_id] += 1# 检查是否所有类别都已选够样本if all(count >= samples_per_class for count in category_samples.values()):break# 3. 筛选出被选中的图片及其标注filtered_images = [img for img in coco_data['images'] if img['id'] in selected_images]filtered_annotations = [ann for ann in coco_data['annotations'] if ann['image_id'] in selected_images]# 4. 构建新的COCO数据集small_coco = {'info': coco_data['info'],'licenses': coco_data['licenses'],'categories': coco_data['categories'],'images': filtered_images,'annotations': filtered_annotations}# 5. 保存新的标注文件with open(output_json, 'w', encoding='utf-8') as f:json.dump(small_coco, f, indent=2)# 打印统计信息print(f"成功创建小样本数据集!")print(f"原始图片数量: {len(coco_data['images'])}")print(f"筛选后图片数量: {len(filtered_images)}")print(f"每个类别样本数: {samples_per_class}")print(f"保存路径: {output_json}")# 检查每个类别的实际样本数actual_counts = defaultdict(int)for ann in filtered_annotations:actual_counts[ann['category_id']] += 1# 映射类别ID到类别名称id_to_name = {cat['id']: cat['name'] for cat in coco_data['categories']}print("\n每个类别的实际样本数:")for cat_id, count in actual_counts.items():cat_name = id_to_name.get(cat_id, f"类别_{cat_id}")print(f"  {cat_name} (ID:{cat_id}): {count}个样本")if __name__ == "__main__":parser = argparse.ArgumentParser(description='从COCO数据集中创建小样本数据集')parser.add_argument('--input', '-i', required=True, help='原始COCO标注文件路径')parser.add_argument('--output', '-o', required=True, help='输出的小样本COCO标注文件路径')parser.add_argument('--samples', '-s', type=int, default=10, help='每个类别提取的样本数,默认为10')args = parser.parse_args()# 检查输入文件是否存在if not os.path.exists(args.input):print(f"错误: 输入文件 '{args.input}' 不存在")exit(1)# 检查输出目录是否存在,不存在则创建output_dir = os.path.dirname(args.output)if output_dir and not os.path.exists(output_dir):os.makedirs(output_dir)create_small_sample_coco(args.input, args.output, args.samples)

4、再次训练

python tools/train.py -c configs/few-shot/ppyoloe_plus_crn_s_80e_contrast_pcb.yml  --amp  --eval --use_vdl=True --vdl_log_dir=./visdrone/

在39步精度达到0.69%

5、预测一下

python tools/infer.py -c configs/few-shot/ppyoloe_plus_crn_s_80e_contrast_pcb.yml -o weights=output1/best_model.pdparams --infer_img=/home/PaddleDetection/dataset/coco/fruit5_coco/images/106.jpg

6、训练配置

_BASE_: ['../datasets/coco_detection.yml','../runtime.yml','./_base_/optimizer_80e.yml','./_base_/ppyoloe_plus_crn.yml','./_base_/ppyoloe_plus_reader.yml',
]log_iter: 100
snapshot_epoch: 5
weights: output/ppyoloe_plus_crn_s_80e_contrast_pcb/model_finalpretrain_weights: ./ppyoloe_crn_s_obj365_pretrained.pdparams
depth_mult: 0.33
width_mult: 0.50epoch: 190LearningRate:base_lr: 0.0001schedulers:- !CosineDecaymax_epochs: 596- !LinearWarmupstart_factor: 0.epochs: 5YOLOv3:backbone: CSPResNetneck: CustomCSPPANyolo_head: PPYOLOEContrastHeadpost_process: ~PPYOLOEContrastHead:fpn_strides: [32, 16, 8]grid_cell_scale: 5.0grid_cell_offset: 0.5static_assigner_epoch: 100use_varifocal_loss: Trueloss_weight: {class: 1.0, iou: 2.5, dfl: 0.5, contrast: 0.2}static_assigner:name: ATSSAssignertopk: 9assigner:name: TaskAlignedAssignertopk: 13alpha: 1.0beta: 6.0contrast_loss:name: SupContrasttemperature: 100sample_num: 2048thresh: 0.75nms:name: MultiClassNMSnms_top_k: 1000keep_top_k: 300score_threshold: 0.01nms_threshold: 0.7num_classes: 5
metric: COCO
map_type: integralTrainDataset:!COCODataSetimage_dir: imagesanno_path: /home/PaddleDetection/dataset/small.jsondataset_dir: /home/PaddleDetection/dataset/coco/fruit5_coco/data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']EvalDataset:!COCODataSetimage_dir: imagesanno_path: /home/PaddleDetection/dataset/coco/fruit5_coco/annotations/instance_val.jsondataset_dir: /home/PaddleDetection/dataset/coco/fruit5_coco/TestDataset:!ImageFolderanno_path: /home/PaddleDetection/dataset/coco/fruit5_coco/annotations/instance_val.jsondataset_dir: /home/PaddleDetection/dataset/coco/fruit5_coco/

http://www.xdnf.cn/news/16389.html

相关文章:

  • Allegro软件光绘文件Artwork到底如何配置?
  • Python柱状图
  • Lakehouse x AI ,打造智能 BI 新体验
  • 戴尔电脑 Linux 安装与配置指南_导入mysql共享文件夹
  • 关于网络模型
  • FreeRTOS—优先级翻转问题
  • vue项目入门
  • 【C++避坑指南】vector迭代器失效的八大场景与解决方案
  • haproxy七层代理(原理)
  • 从0开始学习R语言--Day57--SCAD模型
  • 深入浅出设计模式——创建型模式之简单工厂模式
  • Hive【Hive架构及工作原理】
  • 如何高效通过3GPP官网查找资料
  • JAVA + 海康威视SDK + FFmpeg+ SRS 实现海康威视摄像头二次开发
  • 服务器托管:网站经常被攻击该怎么办?
  • 学习游戏制作记录(克隆技能)7.25
  • 秋招Day19 - 分布式 - 分布式锁
  • 初识决策树-理论部分
  • 肺癌预测模型实战案例
  • 【自动化运维神器Ansible】Ansible常用模块之Copy模块详解
  • 文件包含学习总结
  • 滑动窗口-7
  • 主要分布在背侧海马体(dHPC)CA1区域(dCA1)的时空联合细胞对NLP中的深层语义分析的积极影响和启示
  • ClickHouse 常用的使用场景
  • AWS WebRTC:我们的业务模式
  • [python][flask]flask蓝图使用方法
  • 【软件工程】构建软件合规防护网:双阶段检查机制的实践之道
  • Android studio自带的Android模拟器都是x86架构的吗,需要把arm架构的app翻译成x86指令?
  • FP16 和 BF16
  • 函数-变量的作用域和生命周期