
A Complete Implementation of Deep-Learning Grasp Pose Generation and 6D Pose Estimation

How to integrate deep-learning models such as GraspNet, together with 6D pose estimation, into ROS2 and MoveIt to build a high-precision robotic grasping system.

1. System Architecture

text

[RGB-D sensor] → [Object detection & 6D pose estimation] → [GraspNet grasp pose generation] → [MoveIt motion planning] → [Grasp execution]

2. Environment Setup

2.1 Install dependencies

bash

# Install PyTorch (pick the build matching your CUDA version)
pip3 install torch torchvision torchaudio

# Install the remaining Python dependencies
pip3 install open3d scipy transforms3d tensorboardx

# Install the ROS2 packages
sudo apt install ros-${ROS_DISTRO}-vision-msgs ros-${ROS_DISTRO}-tf2-geometry-msgs

2.2 Download the GraspNet model

bash

mkdir -p ~/graspnet_ws/src
cd ~/graspnet_ws/src
git clone https://github.com/jsll/pytorch-graspnet.git
cd pytorch-graspnet
pip install -e .

3. 6D Pose Estimation

3.1 Create the pose estimation node

src/pose_estimation.py:

python

#!/usr/bin/env python3
import rclpy
from rclpy.node import Node
from sensor_msgs.msg import Image, CameraInfo, PointCloud2
from vision_msgs.msg import Detection3D, Detection3DArray, BoundingBox3D
from geometry_msgs.msg import Pose
from cv_bridge import CvBridge
import cv2
import numpy as np
import torch
from third_party.DenseFusion.lib.network import PoseNet, PoseRefineNet
from third_party.DenseFusion.lib.utils import *


class PoseEstimator(Node):
    def __init__(self):
        super().__init__('pose_estimator')
        # Load the pretrained models
        self.model = PoseNet(num_points=1000, num_obj=10)
        self.model.load_state_dict(torch.load('path/to/pose_model.pth'))
        self.model.cuda()
        self.refiner = PoseRefineNet(num_points=1000, num_obj=10)
        self.refiner.load_state_dict(torch.load('path/to/pose_refine_model.pth'))
        self.refiner.cuda()

        # Subscribe to the RGB-D data
        self.sub_rgb = self.create_subscription(
            Image, '/camera/color/image_raw', self.rgb_callback, 10)
        self.sub_depth = self.create_subscription(
            Image, '/camera/depth/image_raw', self.depth_callback, 10)
        self.sub_camera_info = self.create_subscription(
            CameraInfo, '/camera/color/camera_info', self.camera_info_callback, 10)

        # Publish the detection results
        self.pub_detections = self.create_publisher(
            Detection3DArray, '/object_detections_3d', 10)

        self.bridge = CvBridge()
        self.camera_matrix = None
        self.dist_coeffs = None
        self.current_rgb = None
        self.current_depth = None

    def camera_info_callback(self, msg):
        # Cache the camera intrinsics
        self.camera_matrix = np.array(msg.k).reshape(3, 3)
        self.dist_coeffs = np.array(msg.d)

    def rgb_callback(self, msg):
        self.current_rgb = self.bridge.imgmsg_to_cv2(msg, 'bgr8')

    def depth_callback(self, msg):
        self.current_depth = self.bridge.imgmsg_to_cv2(msg, desired_encoding='passthrough')
        if self.current_rgb is not None and self.camera_matrix is not None:
            self.process_frame()

    def process_frame(self):
        # 1. Object detection (e.g. YOLOv5 or Mask R-CNN)
        detections = self.detect_objects(self.current_rgb)

        # 2. Estimate a 6D pose for every detected object
        detection_array = Detection3DArray()
        detection_array.header.stamp = self.get_clock().now().to_msg()
        detection_array.header.frame_id = 'camera_color_optical_frame'

        for det in detections:
            # Extract the region of interest
            roi = self.extract_roi(det['bbox'])
            # Estimate the initial pose
            pose = self.estimate_pose(roi)
            # Refine the pose
            refined_pose = self.refine_pose(pose, roi)

            # Build the detection message
            detection = Detection3D()
            detection.bbox.center.position.x = refined_pose[0]
            detection.bbox.center.position.y = refined_pose[1]
            detection.bbox.center.position.z = refined_pose[2]
            detection.bbox.size.x = det['width']
            detection.bbox.size.y = det['height']
            detection.bbox.size.z = det['depth']

            # Set the orientation (quaternion representation)
            q = euler_to_quaternion(refined_pose[3:])
            detection.bbox.center.orientation.x = q[0]
            detection.bbox.center.orientation.y = q[1]
            detection.bbox.center.orientation.z = q[2]
            detection.bbox.center.orientation.w = q[3]

            detection_array.detections.append(detection)

        # Publish the detections
        self.pub_detections.publish(detection_array)

    def detect_objects(self, image):
        # Implement the object detection logic here;
        # return a list of entries with bbox and class
        pass

    def extract_roi(self, bbox):
        # Crop the region of interest from the bbox
        pass

    def estimate_pose(self, roi):
        # Estimate the initial pose with DenseFusion
        pass

    def refine_pose(self, initial_pose, roi):
        # Refine the pose with the refiner network
        pass


def main(args=None):
    rclpy.init(args=args)
    node = PoseEstimator()
    rclpy.spin(node)
    node.destroy_node()
    rclpy.shutdown()


if __name__ == '__main__':
    main()
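The node above relies on `euler_to_quaternion`, which is expected to come from the DenseFusion utils wildcard import. If it is not available in your checkout, a minimal stand-in built on SciPy could look like this (a sketch, not the DenseFusion implementation; the `(x, y, z, w)` ordering matches how the node fills the quaternion fields):

```python
import numpy as np
from scipy.spatial.transform import Rotation


def euler_to_quaternion(euler_angles):
    # 'xyz' = extrinsic rotations about the fixed x, y, z axes;
    # adjust the sequence if your pose parameterization differs.
    r = Rotation.from_euler('xyz', euler_angles)
    return r.as_quat()  # SciPy returns (x, y, z, w)
```

Note that SciPy's `as_quat()` already uses the `(x, y, z, w)` convention of `geometry_msgs`, so the four components can be assigned to the message fields directly.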

4. GraspNet Integration

4.1 Create the grasp pose generation node

src/graspnet_node.py:

python

#!/usr/bin/env python3
import rclpy
from rclpy.node import Node
from sensor_msgs.msg import PointCloud2
from vision_msgs.msg import Detection3DArray
from geometry_msgs.msg import PoseArray, Pose
from graspnetAPI import GraspNet
import numpy as np
import open3d as o3d
from cv_bridge import CvBridge


class GraspNetNode(Node):
    def __init__(self):
        super().__init__('graspnet_node')
        # Initialize GraspNet
        self.graspnet = GraspNet(root='/path/to/graspnet_dataset',
                                 camera='realsense', split='seen')

        # Subscribe to the point cloud and the detection results
        self.sub_pc = self.create_subscription(
            PointCloud2, '/camera/depth/points', self.pc_callback, 10)
        self.sub_detections = self.create_subscription(
            Detection3DArray, '/object_detections_3d', self.detection_callback, 10)

        # Publish the grasp poses
        self.pub_grasps = self.create_publisher(PoseArray, '/grasp_poses', 10)

        self.bridge = CvBridge()
        self.current_pc = None
        self.current_detections = None

    def pc_callback(self, msg):
        # Convert the point cloud to Open3D format
        self.current_pc = self.pointcloud2_to_o3d(msg)

    def detection_callback(self, msg):
        self.current_detections = msg
        if self.current_pc is not None:
            self.process_detections()

    def process_detections(self):
        grasp_poses = PoseArray()
        grasp_poses.header = self.current_detections.header

        for detection in self.current_detections.detections:
            # 1. Extract the object point cloud
            obj_pc = self.crop_pointcloud(detection)
            # 2. Generate grasp poses
            grasps = self.graspnet.get_grasps(obj_pc)
            # 3. Filter and rank the grasp poses
            valid_grasps = self.filter_grasps(grasps)

            # 4. Convert to ROS messages
            for grasp in valid_grasps[:5]:  # keep the 5 best grasps
                pose = Pose()
                pose.position.x = grasp.translation[0]
                pose.position.y = grasp.translation[1]
                pose.position.z = grasp.translation[2]
                # Grasp orientation (converted to a quaternion)
                pose.orientation = self.matrix_to_quaternion(grasp.rotation_matrix)
                grasp_poses.poses.append(pose)

        # Publish the grasp poses
        self.pub_grasps.publish(grasp_poses)

    def pointcloud2_to_o3d(self, msg):
        # Convert a ROS PointCloud2 into an Open3D point cloud
        pass

    def crop_pointcloud(self, detection):
        # Crop the point cloud with the detection's bounding box
        pass

    def filter_grasps(self, grasps):
        # Filter and sort by grasp quality score
        return sorted(grasps, key=lambda x: x.score, reverse=True)

    def matrix_to_quaternion(self, matrix):
        # Convert a rotation matrix to a quaternion
        pass


def main(args=None):
    rclpy.init(args=args)
    node = GraspNetNode()
    rclpy.spin(node)
    node.destroy_node()
    rclpy.shutdown()


if __name__ == '__main__':
    main()
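The `matrix_to_quaternion` stub above can be filled in with SciPy. A minimal sketch (the returned array would still need to be copied field by field into a `geometry_msgs/Quaternion` message rather than assigned directly):

```python
import numpy as np
from scipy.spatial.transform import Rotation


def matrix_to_quaternion(matrix):
    # as_quat() returns (x, y, z, w), the same component order
    # that geometry_msgs/Quaternion uses.
    return Rotation.from_matrix(np.asarray(matrix)).as_quat()
```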

5. MoveIt Integration

5.1 Create the grasp execution node

src/grasp_executor.py:

python

#!/usr/bin/env python3
import rclpy
from rclpy.node import Node
from geometry_msgs.msg import PoseArray, PoseStamped
from moveit_msgs.msg import CollisionObject
from shape_msgs.msg import SolidPrimitive
from moveit_msgs.srv import GetMotionPlan
import copy

from moveit import MoveGroupInterface
# NOTE: PlanningSceneInterface must also be imported from your MoveIt 2
# Python bindings; the module path depends on the bindings you use.


class GraspExecutor(Node):
    def __init__(self):
        super().__init__('grasp_executor')
        # Subscribe to the grasp poses
        self.sub_grasps = self.create_subscription(
            PoseArray, '/grasp_poses', self.grasp_callback, 10)

        # MoveIt interface
        self.move_group = MoveGroupInterface(self, "arm_group", "robot_description")

        # Planning scene interface
        self.planning_scene = PlanningSceneInterface(self)

    def grasp_callback(self, msg):
        if not msg.poses:
            return
        # 1. Pick the best grasp pose
        target_grasp = self.select_best_grasp(msg.poses)
        # 2. Add obstacles to the planning scene
        self.add_obstacles()
        # 3. Plan and execute the grasp
        self.execute_grasp(target_grasp, msg.header)

    def select_best_grasp(self, poses):
        # More sophisticated selection logic can go here
        return poses[0]

    def add_obstacles(self):
        # Add the table and other obstacles
        table = CollisionObject()
        table.id = "table"
        table.header.frame_id = "world"

        table_primitive = SolidPrimitive()
        table_primitive.type = SolidPrimitive.BOX
        table_primitive.dimensions = [1.0, 1.0, 0.02]  # size

        table_pose = PoseStamped()
        table_pose.header.frame_id = "world"
        table_pose.pose.position.z = -0.01  # slightly below the origin

        table.primitives.append(table_primitive)
        table.primitive_poses.append(table_pose.pose)
        self.planning_scene.add_object(table)

    def execute_grasp(self, grasp_pose, header):
        # 1. Build the pre-grasp pose (retreat 10 cm); deep-copy so the
        #    grasp pose itself is not mutated
        pregrasp_pose = PoseStamped()
        pregrasp_pose.header = header
        pregrasp_pose.pose = copy.deepcopy(grasp_pose)
        pregrasp_pose.pose.position.z += 0.1

        # 2. Move to the pre-grasp position
        self.move_group.set_pose_target(pregrasp_pose)
        self.move_group.go(wait=True)

        # 3. Move to the grasp position
        target_pose = PoseStamped()
        target_pose.header = header
        target_pose.pose = grasp_pose
        self.move_group.set_pose_target(target_pose)
        self.move_group.go(wait=True)

        # 4. Close the gripper
        self.close_gripper()

        # 5. Retreat to the pre-grasp position
        self.move_group.set_pose_target(pregrasp_pose)
        self.move_group.go(wait=True)

    def close_gripper(self):
        # Implement the gripper control logic here
        pass


def main(args=None):
    rclpy.init(args=args)
    node = GraspExecutor()
    rclpy.spin(node)
    node.destroy_node()
    rclpy.shutdown()


if __name__ == '__main__':
    main()
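Retreating along the world z axis only works for top-down grasps. A more general pre-grasp backs off along the gripper's approach axis. A small numpy sketch, assuming the approach direction is the third column of the grasp rotation matrix (adjust to your gripper convention):

```python
import numpy as np


def pregrasp_position(grasp_position, grasp_rotation, retreat=0.1):
    # Approach axis: assumed here to be the third column of the
    # rotation matrix (tool z pointing toward the object).
    approach = np.asarray(grasp_rotation)[:, 2]
    # Back off 'retreat' meters opposite the approach direction.
    return np.asarray(grasp_position) - retreat * approach
```

The resulting position would replace the `position.z += 0.1` offset in `execute_grasp`, while the orientation stays that of the grasp itself.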

6. System Integration and Optimization

6.1 Launch file

launch/graspnet_grasping.launch.py:

python

from launch import LaunchDescription
from launch_ros.actions import Node


def generate_launch_description():
    return LaunchDescription([
        # Orbbec Astra camera
        Node(
            package='astra_camera',
            executable='astra_camera_node',
            name='camera',
            parameters=[{
                'depth_registration': True,
                'depth_mode': '640x480',
                'color_mode': '640x480',
            }]),
        # 6D pose estimation node
        Node(
            package='graspnet_demo',
            executable='pose_estimation.py',
            name='pose_estimator'),
        # GraspNet node
        Node(
            package='graspnet_demo',
            executable='graspnet_node.py',
            name='graspnet_node'),
        # Grasp execution node
        Node(
            package='graspnet_demo',
            executable='grasp_executor.py',
            name='grasp_executor'),
        # MoveIt
        Node(
            package='moveit_ros_move_group',
            executable='move_group',
            name='move_group',
            parameters=[{'robot_description': '/robot_description'}]),
    ])

6.2 Performance optimization tips

  1. Parallel processing:

    • Use ROS2 asynchronous callback groups to process images and point clouds concurrently

    • Run pose estimation and grasp generation in parallel on multiple GPUs

  2. Model optimization:

    • Accelerate inference with TensorRT

    • Quantize the models to reduce memory usage

  3. Caching:

    • Cache grasp poses for static objects

    • Update poses incrementally

  4. Real-time performance:

    • Use low-resolution inputs for fast estimation

    • Generate grasp poses in multiple stages (fast generation + fine refinement)
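The caching idea above can be sketched as a small time-stamped store keyed by object id, so that a static object's grasps are reused until they go stale. All names here are illustrative:

```python
import time


class GraspCache:
    """Time-stamped cache of grasp poses per object id (illustrative)."""

    def __init__(self, max_age=5.0):
        self.max_age = max_age  # seconds before a cached entry goes stale
        self._cache = {}

    def put(self, obj_id, grasps, now=None):
        # Store the grasps together with the time they were computed.
        self._cache[obj_id] = (now if now is not None else time.monotonic(), grasps)

    def get(self, obj_id, now=None):
        entry = self._cache.get(obj_id)
        if entry is None:
            return None
        stamp, grasps = entry
        now = now if now is not None else time.monotonic()
        if now - stamp > self.max_age:
            del self._cache[obj_id]  # evict stale entries
            return None
        return grasps
```

In the GraspNet node, `process_detections` would consult the cache first and only call `get_grasps` on a miss; any detected object motion should invalidate the entry.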

7. Evaluation and Debugging

  1. Visualization tools:

    bash

    # Visualize in RViz
    ros2 run rviz2 rviz2 -d $(ros2 pkg prefix graspnet_demo)/share/graspnet_demo/config/graspnet.rviz

    # Monitor topics in real time
    ros2 topic hz /grasp_poses

  2. Evaluation metrics:

    • Pose estimation accuracy (ADD/ADD-S metrics)

    • Grasp success rate

    • End-to-end latency from detection to execution

  3. Debugging tips:

    • Record problem scenes with ros2 bag record

    • Publish visualization debug markers

    • Add fine-grained log-level control
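The ADD metric mentioned above is the mean distance between the object's model points transformed by the estimated pose and by the ground-truth pose. A minimal numpy sketch (ADD-S, used for symmetric objects, would instead take the nearest-neighbor distance between the two transformed point sets):

```python
import numpy as np


def add_metric(points, R_est, t_est, R_gt, t_gt):
    # points: (N, 3) model points; R: (3, 3) rotations; t: (3,) translations.
    pred = points @ np.asarray(R_est).T + np.asarray(t_est)
    gt = points @ np.asarray(R_gt).T + np.asarray(t_gt)
    # Mean per-point Euclidean distance between the two transformed clouds.
    return np.linalg.norm(pred - gt, axis=1).mean()
```

A pose is conventionally counted as correct when ADD falls below a threshold such as 10% of the object's diameter.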

With this integration approach, you can build an intelligent grasping system that combines state-of-the-art deep learning with classical motion planning, capable of grasping a wide variety of objects in complex scenes.
