当前位置: 首页 > ai >正文

YOLO-World 部署踩坑记录

YOLO-World 可以根据提示词检测物体。

项目地址:

AILab-CVC/YOLO-World:[CVPR 2024] 实时开放词汇对象检测

1 环境搭建

由于YOLO-World 需要使用mmcv库,这个库对依赖库要求比较严格,所以环境搭建耗费不少时间。

mmcv 要求版本低于 2.1,我安装的版本是 2.0.1。

mmcv 下载地址:

Installation — mmcv 2.2.0 文档

可以自行根据cuda 和torch 版本选择对应的mmcv 版本

注意的坑:

mmcv 版本 不可过高,不然和mmdet不兼容

cuda 版本和torch版本不可过高,不然对应的mmcv版本没有2.1以下。

搭建过程中,我使用的是 CUDA 11.8 加 torch 2.0.1,成功搭建了环境。

transformers 库同样不能太高版本,4.3x 即可

其他库正常安装就好。

2 项目代码踩坑。

我使用项目 demo 里的推理代码时,会报许多函数相关的错误(主要是传参、接收参数不匹配的问题),修改起来比较麻烦,所以我把修改后的完整代码贴在博客最后。

主要出问题的是

YOLO-World-master\yolo_world\models\detectors路径下的 yolo_world.py 文件

3 模型问题

项目官方并没有提供关于文本处理的模型的下载地址,需要自己去huggingface下载,下载地址如下:huggingface.co/openai/clip-vit-base-patch32/tree/main

同样另一个模型,在配置py文件里名为

yolo_world_l_clip_t2i_bn_2e-3adamw_32xb16-100e_obj365v1_goldg_cc3mlite_train-ca93cd1f.pth

直接搜索是找不到这个模型的下载地址的,这个模型其实就是官方项目里提供的 YOLO-World 检测模型。

 配置文件里如果选择coco数据集,则需要提供coco数据集标注数据

instances_train2017.json

这个网上资源很多,可以自行下载。

4 效果展示

提示词:bus

提示词:person

5   修改后的完整 yolo_world.py 

# Copyright (c) Tencent Inc. All rights reserved.
from typing import List, Tuple, Union
import torch
import torch.nn as nn
from torch import Tensor
from mmdet.structures import OptSampleList, SampleList
from mmyolo.models.detectors import YOLODetector
from mmyolo.registry import MODELS@MODELS.register_module()
class YOLOWorldDetector(YOLODetector):"""Implementation of YOLOW Series"""def __init__(self,*args,mm_neck: bool = False,num_train_classes=80,num_test_classes=80,**kwargs) -> None:self.mm_neck = mm_neckself.num_train_classes = num_train_classesself.num_test_classes = num_test_classessuper().__init__(*args, **kwargs)def loss(self, batch_inputs: Tensor,batch_data_samples: SampleList) -> Union[dict, list]:"""Calculate losses from a batch of inputs and data samples."""self.bbox_head.num_classes = self.num_train_classesimg_feats, txt_feats, txt_masks = self.extract_feat(batch_inputs, batch_data_samples)losses = self.bbox_head.loss(img_feats, txt_feats, txt_masks,batch_data_samples)return lossesdef predict(self,batch_inputs: Tensor,batch_data_samples: SampleList,rescale: bool = True) -> SampleList:"""Predict results from a batch of inputs and data samples with post-processing."""img_feats, txt_feats, txt_masks = self.extract_feat(batch_inputs, batch_data_samples)# self.bbox_head.num_classes = self.num_test_classesself.bbox_head.num_classes = txt_feats[0].shape[0]results_list = self.bbox_head.predict(img_feats,txt_feats,txt_masks,batch_data_samples,rescale=rescale)batch_data_samples = self.add_pred_to_datasample(batch_data_samples, results_list)return batch_data_samplesdef reparameterize(self, texts: List[List[str]]) -> None:# encode text embeddings into the detectorself.texts = textsself.text_feats, _ = self.backbone.forward_text(texts)def _forward(self,batch_inputs: Tensor,batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:"""Network forward process. 
Usually includes backbone, neck and headforward without any post-processing."""img_feats, txt_feats, txt_masks = self.extract_feat(batch_inputs, batch_data_samples)results = self.bbox_head.forward(img_feats, txt_feats, txt_masks)return resultsdef extract_feat(self, batch_inputs: Tensor,batch_data_samples: SampleList) -> Tuple[Tuple[Tensor], Tensor]:"""Extract features."""txt_feats = Noneif batch_data_samples is None:texts = self.textstxt_feats = self.text_featselif isinstance(batch_data_samples,dict) and 'texts' in batch_data_samples:texts = batch_data_samples['texts']elif isinstance(batch_data_samples, list) and hasattr(batch_data_samples[0], 'texts'):texts = [data_sample.texts for data_sample in batch_data_samples]elif hasattr(self, 'text_feats'):texts = self.textstxt_feats = self.text_featselse:raise TypeError('batch_data_samples should be dict or list.')if txt_feats is not None:# forward image onlyimg_feats = self.backbone.forward_image(batch_inputs)else:img_feats, (txt_feats,txt_masks) = self.backbone(batch_inputs, texts)if self.with_neck:if self.mm_neck:img_feats = self.neck(img_feats, txt_feats)else:img_feats = self.neck(img_feats)return img_feats, txt_feats, txt_masks@MODELS.register_module()
class SimpleYOLOWorldDetector(YOLODetector):
    """YOLO-World variant that uses prompt embeddings instead of a text encoder.

    Unless ``reparameterized`` is True, the detector owns an ``embeddings``
    parameter (loaded from ``embedding_path`` or randomly initialised) that
    plays the role of the text features.

    Args:
        mm_neck (bool): When True, the neck is called with both image and
            text features; otherwise it receives image features only.
        num_train_classes (int): Value written to ``bbox_head.num_classes``
            before computing the loss (stored as ``num_training_classes``).
        num_test_classes (int): Value written to ``bbox_head.num_classes``
            in :meth:`predict`.
        prompt_dim (int): Width of each randomly initialised embedding and of
            the adapter layers.
        num_prompts (int): Number of randomly initialised embeddings.
        embedding_path (str): Path passed to ``numpy.load``; when empty the
            embeddings are randomly initialised.
        reparameterized (bool): When True, no embeddings or adapter are
            created and the head is called without text features.
        freeze_prompt (bool): When True, the embeddings do not require grad.
        use_mlp_adapter (bool): When True, a two-layer MLP is applied to the
            embeddings with a residual connection.
    """

    def __init__(self,
                 *args,
                 mm_neck: bool = False,
                 num_train_classes=80,
                 num_test_classes=80,
                 prompt_dim=512,
                 num_prompts=80,
                 embedding_path='',
                 reparameterized=False,
                 freeze_prompt=False,
                 use_mlp_adapter=False,
                 **kwargs) -> None:
        # Plain (non-module) attributes are set before the parent constructor
        # runs, matching the original statement order.
        self.mm_neck = mm_neck
        self.num_training_classes = num_train_classes
        self.num_test_classes = num_test_classes
        self.prompt_dim = prompt_dim
        self.num_prompts = num_prompts
        self.reparameterized = reparameterized
        self.freeze_prompt = freeze_prompt
        self.use_mlp_adapter = use_mlp_adapter
        super().__init__(*args, **kwargs)

        # A reparameterized model carries no prompt embeddings or adapter.
        if self.reparameterized:
            return

        if embedding_path:
            import numpy as np
            self.embeddings = torch.nn.Parameter(
                torch.from_numpy(np.load(embedding_path)).float())
        else:
            # Random init: unit-normalised Gaussian rows.
            embeddings = nn.functional.normalize(
                torch.randn((num_prompts, prompt_dim)), dim=-1)
            self.embeddings = nn.Parameter(embeddings)

        self.embeddings.requires_grad = not self.freeze_prompt

        if use_mlp_adapter:
            self.adapter = nn.Sequential(
                nn.Linear(prompt_dim, prompt_dim * 2), nn.ReLU(True),
                nn.Linear(prompt_dim * 2, prompt_dim))
        else:
            self.adapter = None

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples."""
        self.bbox_head.num_classes = self.num_training_classes
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        if self.reparameterized:
            losses = self.bbox_head.loss(img_feats, batch_data_samples)
        else:
            losses = self.bbox_head.loss(img_feats, txt_feats,
                                         batch_data_samples)
        return losses

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs with post-processing.

        Returns:
            The input data samples after ``add_pred_to_datasample`` has
            attached the head's predictions to them.
        """
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)

        self.bbox_head.num_classes = self.num_test_classes
        if self.reparameterized:
            results_list = self.bbox_head.predict(img_feats,
                                                  batch_data_samples,
                                                  rescale=rescale)
        else:
            results_list = self.bbox_head.predict(img_feats,
                                                  txt_feats,
                                                  batch_data_samples,
                                                  rescale=rescale)

        batch_data_samples = self.add_pred_to_datasample(
            batch_data_samples, results_list)
        return batch_data_samples

    def _forward(
            self,
            batch_inputs: Tensor,
            batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
        """Run feature extraction and the head's forward, no post-processing."""
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        if self.reparameterized:
            results = self.bbox_head.forward(img_feats)
        else:
            results = self.bbox_head.forward(img_feats, txt_feats)
        return results

    def extract_feat(
            self, batch_inputs: Tensor,
            batch_data_samples: SampleList) -> Tuple[Tuple[Tensor], Tensor]:
        """Extract image features and build the prompt-based text features.

        ``batch_data_samples`` is accepted for interface parity but not read.

        Returns:
            ``(img_feats, txt_feats)``; ``txt_feats`` is ``None`` when the
            model is reparameterized.
        """
        # Image branch only: the backbone is called with no text input.
        img_feats, _ = self.backbone(batch_inputs, None)

        if not self.reparameterized:
            # Add a leading batch axis to the learned embeddings.
            txt_feats = self.embeddings[None]
            if self.adapter is not None:
                # Residual adapter followed by L2 re-normalisation.
                txt_feats = self.adapter(txt_feats) + txt_feats
                txt_feats = nn.functional.normalize(txt_feats, dim=-1, p=2)
            # One copy of the embeddings per image in the batch.
            txt_feats = txt_feats.repeat(img_feats[0].shape[0], 1, 1)
        else:
            txt_feats = None

        if self.with_neck:
            if self.mm_neck:
                img_feats = self.neck(img_feats, txt_feats)
            else:
                img_feats = self.neck(img_feats)
        return img_feats, txt_feats

http://www.xdnf.cn/news/13343.html

相关文章:

  • Linux611 libvirtb ;FTP vsftpd.conf部分配置文件
  • React 元素的生命周期
  • 从硬件视角审视Web3安全:CertiK CTO主持Proof of Talk圆桌论坛
  • GO 入门小项目-博客-结合Gin Gorm
  • 【面板数据】中国与世界各国农产品进出口贸易数据(2015-2024年)
  • 国内外数据要素标准有何不同?
  • K8S项目需求分析
  • 火山引擎发布豆包大模型 1.6 与视频生成模型 Seedance 1.0 pro
  • ABP vNext + Spark on Hadoop:实时流处理与微服务融合
  • 嵌入式学习笔记 - C语言访问地址的方式,以及指针的进一步理解
  • JMeter 处理 UTF-16 转 UTF-8 乱码问题解决方案(deepseek)
  • AnythingLLM配置Milvus后,上传文档提示向量数据库标识符错误的解决办法
  • 鹰盾Win播放器作为专业的视频安全解决方案,除了硬件翻录外还有什么呢?
  • 微信小程序分享带参数地址
  • UFS-Ver3.1-第八章
  • 6.11 打卡
  • 对话机器人预测场景与 Prompt / 模型选择指南
  • 探究:什么是扁平化组织?有什么益处?
  • gitlab相关操作
  • 实战案例-FPGA的JESD204调试问题解析
  • 青少年编程与数学 01-011 系统软件简介 13 Microsoft SQL Server数据库
  • 关于使用WebSocket时无法使用@Autowired 注入的问题
  • CompletableFuture浅谈
  • Efficient Attention 理解
  • 美团完整面经
  • Matlab解决无法读取路径中的空格
  • matlab分布式电源微电网潮流
  • uni-app 自定义路由封装模块详解(附源码逐行解读)
  • FEMFAT许可使用数据分析工具介绍
  • MySQL 主从复制与一主多从架构实战详解