MediaPipe如何训练自己的手势数据
前言:由于Google上面提供的默认模型只包含7种手势,如何自定义自己的模型,虽然官网给了示例,但是是基于Google的平台Colab运行的,这个需要传文件到Google云盘,然后也比较麻烦,那么如何在本地运行,折腾了挺久,网上太多无效的文章,所以在这里分享。
官网地址为 手势识别模型自定义指南
1、首先安装依赖,如果报错可以见另外一篇文章:安装mediapipe-model-maker报错解决
本人在linux上使用python=3.10安装成功
2、查看模型接收的格式,可以运行官网示例,下载官网示例数据集进行查看。格式为每个文件夹下面是图片,不像yolo有label数据,这个很重要!rps_data_sample 数据集在我的资源免费下载 资源下载
3、编写代码
import os
from mediapipe_model_maker import gesture_recognizer# 1. 设置数据集路径(替换为你的实际路径,我用的Windows下面的wsl)
DATASET_PATH = "/mnt/d/xxx/rps_data_sample"# 2. 验证数据集结构
print("数据集路径:", DATASET_PATH)
labels = [label for label in os.listdir(DATASET_PATH)if os.path.isdir(os.path.join(DATASET_PATH, label))]
print("检测到的标签:", labels) # 应输出: ['none', 'paper', 'rock', 'scissors']# 3. 加载并分割数据集
data = gesture_recognizer.Dataset.from_folder(dirname=DATASET_PATH,hparams=gesture_recognizer.HandDataPreprocessingParams(shuffle=True,min_detection_confidence=0.5 # 手部检测置信度阈值)
)
train_data, rest_data = data.split(0.8) # 80%训练
validation_data, test_data = rest_data.split(0.5) # 10%验证, 10%测试# 4. 配置训练参数(自定义关键参数)
hparams = gesture_recognizer.HParams(epochs=15, # 增加训练轮次batch_size=16, # 根据GPU内存调整learning_rate=0.001,lr_decay=0.95, # 学习率衰减export_dir="rps_model" # 模型输出目录
)
options = gesture_recognizer.GestureRecognizerOptions(hparams=hparams,model_options=gesture_recognizer.ModelOptions(dropout_rate=0.1, # 防止过拟合layer_widths=[64, 32] # 添加2个隐藏层)
)# 5. 训练模型
print("\n开始训练模型...")
model = gesture_recognizer.GestureRecognizer.create(train_data=train_data,validation_data=validation_data,options=options
)# 6. 评估模型
print("\n评估模型性能:")
loss, accuracy = model.evaluate(test_data)
print(f"测试集损失: {loss:.4f}, 准确率: {accuracy*100:.2f}%")# 7. 导出TFLite模型(自动包含元数据)
model.export_model("rps_gesture.tflite")
print("\n模型已导出为 rps_gesture.tflite")# 8. 测试单张图片(可选)
test_image = os.path.join(DATASET_PATH, "rock\1.jpg") # 替换为你的测试图片
if os.path.exists(test_image):result = model.recognize(test_image)top_gesture = result.gestures[0][0]print(f"\n测试图片 '{test_image}' 的预测结果:")print(f"手势: {top_gesture.category_name}, 置信度: {top_gesture.score:.2%}")
else:print(f"\n测试图片不存在: {test_image}")
执行完成后,会在同级目录下生成rps_model 文件夹,训练好的模型如下
4、自定义数据集
整理好自己的数据集,修改数据集路径即可,hagrid-sample-30k-384p 数据集是hagrid的精简版,包含18种手势和无手势,可在我的资源里面进行下载