当前位置: 首页 > news >正文

使用python进行船舶轨迹跟踪

一、系统概述

该系统基于 YOLOv8 深度学习模型和计算机视觉技术,实现对视频或摄像头画面中的船舶进行实时检测、跟踪,并计算船舶航向。支持透视变换校准(鸟瞰图显示)、多目标跟踪、轨迹存储及视频录制功能,适用于港口监控、航道管理、船舶行为分析等场景。

二、依赖库

python

运行

import cv2         # 计算机视觉处理(OpenCV库)
import numpy as np # 数值计算
import time        # 时间处理
import os          # 文件与目录操作
from datetime import datetime # 日期时间处理
from ultralytics import YOLO  # YOLOv8深度学习模型

三、类定义:ShipTracker

3.1 构造函数 __init__

功能

初始化船舶跟踪器,配置视频源、输出参数、YOLOv8 模型及跟踪参数。

参数说明
参数名类型默认值描述
video_sourceint/str0视频源(0为默认摄像头,或指定视频文件路径)
save_videoboolFalse是否保存处理后的视频
show_warpedboolTrue是否显示透视变换后的鸟瞰图
model_pathstryolov8n.ptYOLOv8 模型路径(默认为 COCO 预训练的小模型)
内部属性
  • 视频源与基础参数
    • cap:视频捕获对象(cv2.VideoCapture实例)
    • frame_width/frame_height:视频帧宽高
    • fps:帧率
  • 输出配置
    • output_folder:输出文件夹(默认output
    • out:视频写入对象(cv2.VideoWriter实例,仅当save_video=True时创建)
  • 深度学习模型
    • model:YOLOv8 模型实例
    • ship_class_id:船舶类别 ID(COCO 数据集中为8
  • 检测参数
    • confidence_threshold:置信度阈值(过滤低置信度检测结果)
    • nms_threshold:非极大值抑制阈值(过滤重叠检测框)
  • 跟踪参数
    • trajectories:存储轨迹的字典(键为船舶 ID,值为轨迹信息)
    • max_disappeared_frames:允许目标消失的最大帧数(超过则删除轨迹)
    • max_distance:轨迹匹配的最大距离(像素)
    • min_trajectory_points:计算航向所需的最小轨迹点数
  • 透视变换
    • perspective_transform:透视变换矩阵(校准后生成)
    • warped_width/warped_height:鸟瞰图尺寸(默认 800×800)

3.2 方法列表

3.2.1 calibrate_perspective()
  • 功能:通过鼠标点击选择 4 个点,校准透视变换矩阵,生成鸟瞰图。
  • 操作说明
    1. 显示视频第一帧,按顺序点击左上、右上、右下、左下四个点,形成矩形区域。
    2. q键退出校准。
  • 返回值boolTrue为校准成功,False为取消或失败)
3.2.2 detect_ships(frame)
  • 功能:使用 YOLOv8 模型检测图像中的船舶。
  • 输入frame(BGR 格式图像)
  • 处理流程
    1. 调用 YOLOv8 模型进行预测,指定类别为船舶(ID=8)。
    2. 过滤低于置信度阈值的检测结果。
    3. 应用非极大值抑制(NMS)消除重叠框。
  • 返回值:船舶检测结果列表(每个元素为字典,包含bboxcenterconfidenceclass
3.2.3 calculate_heading(positions)
  • 功能:根据船舶轨迹点计算航向角(0-360 度,0 为正北,顺时针增加)。
  • 输入positions(轨迹点列表,每个点为(x, y)坐标)
  • 算法逻辑
    1. 选择最近的min_trajectory_points个点。
    2. 使用最小二乘法拟合直线。
    3. 计算直线角度并转换为航向角。
  • 返回值:航向角(浮点数,单位为度)或None(轨迹点不足时)
3.2.4 track_ships(detected_ships)
  • 功能:根据检测结果更新船舶轨迹。
  • 输入detected_shipsdetect_ships返回的船舶列表)
  • 算法逻辑
    1. 计算现有轨迹与新检测的匹配距离(欧氏距离),优先匹配近距离目标。
    2. 未匹配的轨迹:若连续消失超过max_disappeared_frames,则删除。
    3. 未匹配的检测:创建新轨迹,分配唯一 ID。
3.2.5 draw_results(frame, ships)
  • 功能:在图像上绘制检测框、轨迹、航向及统计信息,支持鸟瞰图显示。
  • 输入
    • frame:原始帧
    • ships:检测到的船舶列表
  • 输出:绘制后的结果图像(若show_warped=True,则为原始帧与鸟瞰图的横向拼接图)
3.2.6 save_trajectories()
  • 功能:将当前所有轨迹数据保存到文本文件,包含 ID、起始时间、轨迹点坐标及平均航向。
  • 存储路径output_folder/ship_trajectories_时间戳.txt
3.2.7 run()
  • 功能:运行跟踪主循环,处理视频流并实时显示结果。
  • 操作说明
    • q键退出程序。
    • s键保存当前轨迹数据。
  • 流程
    1. 调用calibrate_perspective()进行透视校准(可选)。
    2. 逐帧读取视频,检测、跟踪船舶,绘制结果。
    3. 释放资源并关闭窗口。

四、主程序入口

python

运行

if __name__ == "__main__":tracker = ShipTracker(video_source=0,       # 0为摄像头,或指定视频文件路径(如"ship_video.mp4")save_video=True,      # 启用视频录制show_warped=True,     # 显示鸟瞰图model_path="yolov8n.pt"  # YOLOv8模型路径)tracker.run()

五、使用说明

5.1 环境配置

  1. 安装依赖库:

    bash

    pip install opencv-python numpy ultralytics
    
  2. 下载 YOLOv8 模型(如yolov8n.pt),并指定正确路径。

5.2 透视校准操作

  1. 运行程序后,会弹出窗口提示选择 4 个点。
  2. 按顺序点击视频中的矩形区域四角(如水面区域),生成鸟瞰图。
  3. 校准完成后,右侧会显示鸟瞰图中的船舶轨迹。

5.3 输出文件

  • 视频文件:若save_video=True,生成output/ship_tracking_时间戳.avi
  • 轨迹文件:按s键生成output/ship_trajectories_时间戳.txt,包含各 ID 的坐标序列和航向信息。

六、参数调整建议

参数名作用调整场景
confidence_threshold过滤低置信度的船舶检测结果目标较小或环境复杂时调高
nms_threshold控制非极大值抑制的严格程度船舶密集时调低
max_disappeared_frames目标消失后保留轨迹的帧数船舶被遮挡时间较长时调大
max_distance轨迹匹配的最大允许距离船舶运动速度快时调大
min_trajectory_points计算航向所需的最小轨迹点数航向计算不稳定时调大

七、注意事项

  1. YOLOv8 模型需要一定计算资源,建议在 GPU 环境下运行以提高帧率。
  2. 透视校准的四点应选择实际场景中的矩形区域(如水面边界),以确保鸟瞰图坐标准确。
  3. 船舶航向计算基于轨迹拟合,需要足够的轨迹点才能保证准确性。
  4. 若视频帧率较低,可尝试降低warped_width或关闭show_warped以减少计算量。

完整代码

import cv2
import numpy as np
import time
import os
from datetime import datetime
from ultralytics import YOLOclass ShipTracker:def __init__(self, video_source=0, save_video=False, show_warped=True, model_path="D:/06_Python/20250321_Deep_Learning/yolov8n.pt"):"""初始化船舶跟踪器"""# 视频源设置self.video_source = video_sourceself.cap = cv2.VideoCapture(video_source)if not self.cap.isOpened():raise ValueError("无法打开视频源", video_source)# 获取视频的宽度、高度和帧率self.frame_width = int(self.cap.get(3))self.frame_height = int(self.cap.get(4))self.fps = self.cap.get(cv2.CAP_PROP_FPS)# 输出设置self.save_video = save_videoself.output_folder = "output"self.show_warped = show_warped# 创建输出文件夹if not os.path.exists(self.output_folder):os.makedirs(self.output_folder)# 加载YOLOv8模型self.model = YOLO(model_path)self.ship_class_id = 8  # COCO数据集中船的类别ID# 船舶检测参数self.confidence_threshold = 0.5self.nms_threshold = 0.4# 轨迹存储self.trajectories = {}  # 存储每艘船的轨迹self.next_ship_id = 1  # 下一个可用的船舶IDself.max_disappeared_frames = 15  # 最大消失帧数self.max_distance = 150  # 最大匹配距离self.min_trajectory_points = 5  # 计算航向所需的最小轨迹点# 透视变换参数self.perspective_transform = Noneself.warped_width = 800self.warped_height = 800# 录制设置self.out = Noneif save_video:timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")output_path = os.path.join(self.output_folder, f"ship_tracking_{timestamp}.avi")fourcc = cv2.VideoWriter_fourcc(*'XVID')self.out = cv2.VideoWriter(output_path, fourcc, self.fps, (self.frame_width, self.frame_height))def calibrate_perspective(self):"""校准透视变换,创建鸟瞰图"""print("请在图像中选择4个点,形成一个矩形区域,用于透视变换")print("按顺序点击:左上、右上、右下、左下")# 读取一帧用于选择点ret, frame = self.cap.read()if not ret:print("无法读取视频帧")return False# 创建窗口并设置鼠标回调cv2.namedWindow("选择透视变换点 (按 'q' 退出)")points = []def click_event(event, x, y, flags, param):if event == cv2.EVENT_LBUTTONDOWN:points.append((x, y))cv2.circle(frame, (x, y), 5, (0, 255, 0), -1)cv2.imshow("选择透视变换点 (按 'q' 退出)", frame)cv2.setMouseCallback("选择透视变换点 (按 'q' 退出)", click_event)# 显示图像并等待点击cv2.imshow("选择透视变换点 (按 'q' 退出)", frame)while len(points) < 4:key = cv2.waitKey(1) & 0xFFif key == ord('q'):cv2.destroyAllWindows()return Falsecv2.destroyAllWindows()# 定义目标矩形src = np.float32(points)dst = np.float32([[0, 0],[self.warped_width, 0],[self.warped_width, self.warped_height],[0, self.warped_height]])# 计算透视变换矩阵self.perspective_transform = cv2.getPerspectiveTransform(src, dst)return Truedef detect_ships(self, frame):"""使用YOLOv8检测图像中的船舶"""# 运行模型预测results = self.model(frame, classes=self.ship_class_id, conf=self.confidence_threshold, iou=self.nms_threshold)# 处理检测结果ships = []for result in results:boxes = result.boxes.cpu().numpy()for box in boxes:x1, y1, x2, y2 = box.xyxy[0].astype(int)conf = box.conf[0]cls = int(box.cls[0])# 计算边界框中心点和宽高w, h = x2 - x1, y2 - y1center = (int(x1 + w/2), int(y1 + h/2))ships.append({'bbox': (x1, y1, w, h),'center': center,'confidence': conf,'class': cls})return shipsdef calculate_heading(self, positions):"""根据轨迹点计算船舶航向"""if len(positions) < self.min_trajectory_points:return None# 选择最近的几个点recent_points = positions[-self.min_trajectory_points:]# 拟合直线x = np.array([p[0] for p in recent_points])y = np.array([p[1] for p in recent_points])# 计算直线拟合A = np.vstack([x, np.ones(len(x))]).Tm, c = np.linalg.lstsq(A, y, rcond=None)[0]# 计算角度(弧度)angle = np.arctan2(1, m)  # y轴向下为正# 转换为角度(0-360度,0度为正北,顺时针增加)heading = (np.degrees(angle) + 90) % 360return headingdef track_ships(self, detected_ships):"""跟踪检测到的船舶"""# 计算当前检测点与现有轨迹的距离unmatched_tracks = list(self.trajectories.keys())unmatched_detections = list(range(len(detected_ships)))matches = []# 计算所有可能的匹配for track_id in self.trajectories:trajectory = self.trajectories[track_id]last_position = trajectory['positions'][-1]min_distance = float('inf')min_index = -1for i, ship in enumerate(detected_ships):if i in unmatched_detections:distance = np.sqrt((last_position[0] - ship['center'][0])**2 + (last_position[1] - ship['center'][1])**2)if distance < min_distance and distance < self.max_distance:min_distance = distancemin_index = i# 如果找到匹配if min_index != -1:matches.append((track_id, min_index, min_distance))# 按距离排序,优先处理距离近的匹配matches.sort(key=lambda x: x[2])# 应用匹配for match in matches:track_id, detection_index, _ = matchif track_id in unmatched_tracks and detection_index in unmatched_detections:# 更新轨迹self.trajectories[track_id]['positions'].append(detected_ships[detection_index]['center'])self.trajectories[track_id]['last_seen'] = 0self.trajectories[track_id]['bbox'] = detected_ships[detection_index]['bbox']self.trajectories[track_id]['confidence'] = detected_ships[detection_index]['confidence']# 从待匹配列表中移除unmatched_tracks.remove(track_id)unmatched_detections.remove(detection_index)# 处理未匹配的轨迹for track_id in unmatched_tracks:self.trajectories[track_id]['last_seen'] += 1if self.trajectories[track_id]['last_seen'] > self.max_disappeared_frames:del self.trajectories[track_id]# 处理未匹配的检测结果for detection_index in unmatched_detections:# 创建新轨迹self.trajectories[self.next_ship_id] = {'positions': [detected_ships[detection_index]['center']],'last_seen': 0,'bbox': detected_ships[detection_index]['bbox'],'confidence': detected_ships[detection_index]['confidence'],'start_time': time.time()}self.next_ship_id += 1def draw_results(self, frame, ships):"""在图像上绘制检测和跟踪结果"""output = frame.copy()# 绘制检测到的船舶for ship in ships:x, y, w, h = ship['bbox']cv2.rectangle(output, (x, y), (x + w, y + h), (0, 255, 0), 2)cv2.circle(output, ship['center'], 5, (0, 0, 255), -1)cv2.putText(output, f"Conf: {ship['confidence']:.2f}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)# 绘制轨迹和航向for track_id, trajectory in self.trajectories.items():positions = trajectory['positions']# 绘制轨迹线for i in range(1, len(positions)):cv2.line(output, positions[i-1], positions[i], (255, 0, 0), 2)# 绘制轨迹点for pos in positions:cv2.circle(output, pos, 3, (255, 0, 0), -1)# 计算并绘制航向heading = self.calculate_heading(positions)if heading is not None:center = positions[-1]# 计算航向线终点heading_rad = np.radians(heading - 90)  # 转换为OpenCV坐标系length = 50end_point = (int(center[0] + length * np.cos(heading_rad)),int(center[1] + length * np.sin(heading_rad)))# 绘制航向线cv2.arrowedLine(output, center, end_point, (0, 255, 255), 3, tipLength=0.3)# 显示航向角度cv2.putText(output, f"Heading: {heading:.1f}°", (center[0] + 10, center[1] - 40),cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)# 绘制ID和轨迹长度if len(positions) > 0:last_pos = positions[-1]cv2.putText(output, f"ID: {track_id}", (last_pos[0] + 10, last_pos[1] - 20),cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)cv2.putText(output, f"Points: {len(positions)}", (last_pos[0] + 10, last_pos[1]),cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)# 显示统计信息cv2.putText(output, f"Ships: {len(self.trajectories)}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)cv2.putText(output, f"FPS: {int(self.fps)}", (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)# 创建结果显示窗口if self.show_warped and self.perspective_transform is not None:# 创建鸟瞰图warped = cv2.warpPerspective(output, self.perspective_transform, (self.warped_width, self.warped_height))# 合并显示# 调整图像大小使高度一致if output.shape[0] != warped.shape[0]:scale = output.shape[0] / warped.shape[0]new_width = int(warped.shape[1] * scale)warped = cv2.resize(warped, (new_width, output.shape[0]))combined = np.hstack((output, warped))return combinedreturn outputdef save_trajectories(self):"""保存轨迹数据到文件"""timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")output_path = os.path.join(self.output_folder, f"ship_trajectories_{timestamp}.txt")with open(output_path, 'w') as f:f.write("Ship Trajectories\n")f.write(f"Recorded on: {datetime.now()}\n\n")for track_id, trajectory in self.trajectories.items():f.write(f"Ship ID: {track_id}\n")f.write(f"Start Time: {time.ctime(trajectory['start_time'])}\n")f.write(f"Duration: {time.time() - trajectory['start_time']:.2f} seconds\n")f.write(f"Trajectory Points: {len(trajectory['positions'])}\n")# 计算平均航向heading = self.calculate_heading(trajectory['positions'])if heading is not None:f.write(f"Average Heading: {heading:.1f}°\n")f.write("Positions:\n")for pos in trajectory['positions']:f.write(f"  ({pos[0]}, {pos[1]})\n")f.write("\n")print(f"轨迹数据已保存到: {output_path}")def run(self):"""运行船舶跟踪系统"""# 首先进行透视校准if not self.calibrate_perspective():print("透视校准失败,使用原始视角")print("开始船舶跟踪...")print("按 'q' 退出,按 's' 保存轨迹数据")frame_count = 0start_time = time.time()while True:ret, frame = self.cap.read()if not ret:break# 计算实际帧率frame_count += 1if frame_count % 10 == 0:elapsed_time = time.time() - start_timeself.fps = frame_count / elapsed_time# 检测船舶ships = self.detect_ships(frame)# 跟踪船舶self.track_ships(ships)# 绘制结果result = self.draw_results(frame, ships)# 保存视频if self.save_video:self.out.write(result)# 显示结果cv2.imshow("船舶轨迹跟踪系统 (按 'q' 退出,按 's' 保存轨迹)", result)# 按键处理key = cv2.waitKey(1) & 0xFFif key == ord('q'):breakelif key == ord('s'):self.save_trajectories()# 释放资源self.cap.release()if self.out:self.out.release()cv2.destroyAllWindows()print("船舶跟踪系统已关闭")# 主程序入口
if __name__ == "__main__":# 创建船舶跟踪器实例tracker = ShipTracker(video_source=0,  # 0表示默认摄像头,也可以指定视频文件路径save_video=True,  # 是否保存视频show_warped=True,  # 是否显示鸟瞰图model_path="D:/06_Python/20250321_Deep_Learning/yolov8n.pt"  # YOLOv8模型路径)# 运行跟踪器tracker.run()    

http://www.xdnf.cn/news/505675.html

相关文章:

  • 编译原理7~9
  • 【Element UI】表单及其验证规则详细
  • python运算符
  • python训练营打卡第26天
  • Go语言 Gin框架 使用指南
  • js中不同循环的使用以及结束循环方法
  • 两个电机由同一个控制器控制,其中一个电机发生堵转时,另一个电机的电流会变大,是发生了倒灌现象吗?电流倒灌产生的机理是什么?
  • Gartner《How to Leverage Lakehouse Design in Your DataStrategy》学习心得
  • SAP HCM 0008数据存储逻辑
  • 《棒球万事通》球类运动有哪些项目·棒球1号位
  • c++ 运算符重载
  • 16 C 语言布尔类型与 sizeof 运算符详解:布尔类型的三种声明方式、执行时间、赋值规则
  • qt6 c++操作qtableview和yaml
  • 使用 CodeBuddy 开发一款富交互的屏幕录制与注释分享工具开发纪实
  • C语言查漏补缺
  • Codeforces Round 1024 (Div.2)
  • 【C/C++】C++返回值优化:RVO与NRVO全解析
  • 安全性(三):信息安全的五要素及其含义
  • Python-92:最大乘积区间问题
  • 从AI系统到伦理平台:技术治理的开放转向
  • docker部署第一个Go项目
  • 语音转文字并进行中英文翻译
  • 【JavaScript】 js 基础知识强化复习
  • 2025系统架构师---选择题知识点(押题)
  • JavaScript基础-作用域链
  • vue3: amap using typescript
  • 【2025 技术指南】如何创建和配置国际版 Apple ID
  • DeepSeek 赋能社会科学:解锁研究新范式
  • 第三十四节:特征检测与描述-SIFT/SURF 特征 (专利算法)
  • JavaScript基础-对象的相关概念