当前位置: 首页 > java >正文

MediaPipe如何训练自己的手势数据

前言:由于Google上面提供的默认模型只包含7种手势,如何自定义自己的模型,虽然官网给了示例,但是是基于Google的平台Colab运行的,这个需要传文件到Google云盘,然后也比较麻烦,那么如何在本地运行,折腾了挺久,网上太多无效的文章,所以在这里分享。

官网地址为 手势识别模型自定义指南

在这里插入图片描述

1、首先安装依赖,如果报错可以见另外一篇文章:安装mediapipe-model-maker报错解决
本人在linux上使用python=3.10安装成功

2、查看模型接收的格式,可以运行官网示例,下载官网示例数据集进行查看。格式为每个文件夹下面是图片,不像yolo有label数据,这个很重要!rps_data_sample 数据集在我的资源免费下载 资源下载

在这里插入图片描述

3、编写代码

import os
from mediapipe_model_maker import gesture_recognizer# 1. 设置数据集路径(替换为你的实际路径,我用的Windows下面的wsl)
DATASET_PATH = "/mnt/d/xxx/rps_data_sample"# 2. 验证数据集结构
print("数据集路径:", DATASET_PATH)
labels = [label for label in os.listdir(DATASET_PATH)if os.path.isdir(os.path.join(DATASET_PATH, label))]
print("检测到的标签:", labels)  # 应输出: ['none', 'paper', 'rock', 'scissors']# 3. 加载并分割数据集
data = gesture_recognizer.Dataset.from_folder(dirname=DATASET_PATH,hparams=gesture_recognizer.HandDataPreprocessingParams(shuffle=True,min_detection_confidence=0.5  # 手部检测置信度阈值)
)
train_data, rest_data = data.split(0.8)  # 80%训练
validation_data, test_data = rest_data.split(0.5)  # 10%验证, 10%测试# 4. 配置训练参数(自定义关键参数)
hparams = gesture_recognizer.HParams(epochs=15,               # 增加训练轮次batch_size=16,           # 根据GPU内存调整learning_rate=0.001,lr_decay=0.95,           # 学习率衰减export_dir="rps_model"   # 模型输出目录
)
options = gesture_recognizer.GestureRecognizerOptions(hparams=hparams,model_options=gesture_recognizer.ModelOptions(dropout_rate=0.1,     # 防止过拟合layer_widths=[64, 32] # 添加2个隐藏层)
)# 5. 训练模型
print("\n开始训练模型...")
model = gesture_recognizer.GestureRecognizer.create(train_data=train_data,validation_data=validation_data,options=options
)# 6. 评估模型
print("\n评估模型性能:")
loss, accuracy = model.evaluate(test_data)
print(f"测试集损失: {loss:.4f}, 准确率: {accuracy*100:.2f}%")# 7. 导出TFLite模型(自动包含元数据)
model.export_model("rps_gesture.tflite")
print("\n模型已导出为 rps_gesture.tflite")# 8. 测试单张图片(可选)
test_image = os.path.join(DATASET_PATH, "rock\1.jpg")  # 替换为你的测试图片
if os.path.exists(test_image):result = model.recognize(test_image)top_gesture = result.gestures[0][0]print(f"\n测试图片 '{test_image}' 的预测结果:")print(f"手势: {top_gesture.category_name}, 置信度: {top_gesture.score:.2%}")
else:print(f"\n测试图片不存在: {test_image}")

执行完成后,会在同级目录下生成rps_model 文件夹,训练好的模型如下

在这里插入图片描述

4、自定义数据集
整理好自己的数据集,修改数据集路径即可,hagrid-sample-30k-384p 数据集是hagrid的精简版,包含18种手势和无手势,可在我的资源里面进行下载

http://www.xdnf.cn/news/13628.html

相关文章:

  • Java异步编程:提升性能的实战秘籍
  • TruBit Pro:深化全球布局,拓展战略合作
  • (十三)计算机视觉中的深度学习:特征表示、模型架构与视觉认知原理
  • node-red的http-request组件调研三方接口请求参数为form-data解决方案
  • 数据分析入门初解
  • AbMole| Angiotensin II human(M6240;血管紧张素Ⅱ)
  • Vue.js 中 “require is not defined“
  • 大模型面试题:多模态处理多分辨率输入有哪些方法?
  • SpringMVC与Struts2对比教学
  • DeepSeek 助力 Vue3 开发:打造丝滑的日历(Calendar),日历_学习计划日历示例(CalendarView01_20)
  • 【React】常用的状态管理库比对
  • 短剧系统开发:打造高效、创新的短视频娱乐平台 - 从0到1的完整解决方案
  • [行为型模式]观察者模式
  • 【苍穹外卖项目】Day01
  • Django(自用)
  • Redis:渐进式遍历
  • ArkUI-X构建Android平台AAR及使用
  • ROS2 工作空间中, CMakeLists.txt, setup.py和 package.xml的作用分别是?
  • 【编译原理】题目合集(一)
  • 初识MySQL · 事务 · 下
  • TCP/IP 网络编程 | Reactor事件处理模式
  • 像素跟踪 跟踪像素 算法总结
  • 【慧游鲁博】【12】小程序端 · 智能导览对接后端文物图片识别功能
  • WEB JWT
  • java复习 09
  • 【开源工具】:基于PyQt5的智能网络驱动器映射工具开发全流程(附源码)
  • WWDC 2025 开发者特辑 | 肘子的 Swift 周报 #088
  • 计算机视觉之三维重建(深入浅出SfM与SLAM核心算法)—— 1. 摄像机几何
  • 2025最新软件测试八股文,查漏补缺(含答案+文档)
  • Spring Cloud Gateway 介绍