YOLO-World 部署踩坑记录
YOLO-World:根据文本提示词检测图像中的物体(开放词汇目标检测)
项目地址:
AILab-CVC/YOLO-World:[CVPR 2024] 实时开放词汇目标检测(https://github.com/AILab-CVC/YOLO-World)
1 环境搭建
由于 YOLO-World 依赖 mmcv 库,而 mmcv 对 CUDA、PyTorch 等依赖的版本要求比较严格,所以环境搭建耗费了不少时间。
mmcv 版本要求低于 2.1,我安装的是 2.0.1。
mmcv 下载地址:
Installation — mmcv 2.2.0 文档
可根据自己的 CUDA 和 PyTorch 版本选择对应的 mmcv 版本。
注意的坑:
mmcv 版本不可过高,否则与 mmdet 不兼容。
CUDA 和 PyTorch 版本也不可过高,否则找不到与之对应的 2.1 以下版本的 mmcv。
我使用 CUDA 11.8 + PyTorch 2.0.1 成功搭建了环境。
transformers 库的版本同样不能太高,4.3x 即可。
其他库正常安装就好。
2 项目代码踩坑
使用项目 demo 中的推理代码时,会报许多函数调用错误(传参与接收参数不匹配),修改起来比较麻烦,所以我把修改后的完整代码贴在了博客最后。
主要出问题的是
YOLO-World-master\yolo_world\models\detectors 路径下的 yolo_world.py 文件。
3 模型问题
官方项目没有提供文本编码模型(CLIP)的下载地址,需要自己去 Hugging Face 下载,地址如下:https://huggingface.co/openai/clip-vit-base-patch32/tree/main
另一个模型在配置文件(.py)中的文件名为:
yolo_world_l_clip_t2i_bn_2e-3adamw_32xb16-100e_obj365v1_goldg_cc3mlite_train-ca93cd1f.pth
直接搜索这个文件名是找不到下载地址的——它其实就是官方项目中提供的 YOLO-World 预训练检测模型权重。
如果配置文件中选择的是 COCO 数据集,还需要提供 COCO 数据集的标注文件:
instances_train2017.json
这个网上资源很多,可以自行下载。
4 效果展示
提示词:bus
提示词:person
5 修改后的完整 yolo_world.py
# Copyright (c) Tencent Inc. All rights reserved.
from typing import List, Tuple, Union
import torch
import torch.nn as nn
from torch import Tensor
from mmdet.structures import OptSampleList, SampleList
from mmyolo.models.detectors import YOLODetector
from mmyolo.registry import MODELS@MODELS.register_module()
class YOLOWorldDetector(YOLODetector):"""Implementation of YOLOW Series"""def __init__(self,*args,mm_neck: bool = False,num_train_classes=80,num_test_classes=80,**kwargs) -> None:self.mm_neck = mm_neckself.num_train_classes = num_train_classesself.num_test_classes = num_test_classessuper().__init__(*args, **kwargs)def loss(self, batch_inputs: Tensor,batch_data_samples: SampleList) -> Union[dict, list]:"""Calculate losses from a batch of inputs and data samples."""self.bbox_head.num_classes = self.num_train_classesimg_feats, txt_feats, txt_masks = self.extract_feat(batch_inputs, batch_data_samples)losses = self.bbox_head.loss(img_feats, txt_feats, txt_masks,batch_data_samples)return lossesdef predict(self,batch_inputs: Tensor,batch_data_samples: SampleList,rescale: bool = True) -> SampleList:"""Predict results from a batch of inputs and data samples with post-processing."""img_feats, txt_feats, txt_masks = self.extract_feat(batch_inputs, batch_data_samples)# self.bbox_head.num_classes = self.num_test_classesself.bbox_head.num_classes = txt_feats[0].shape[0]results_list = self.bbox_head.predict(img_feats,txt_feats,txt_masks,batch_data_samples,rescale=rescale)batch_data_samples = self.add_pred_to_datasample(batch_data_samples, results_list)return batch_data_samplesdef reparameterize(self, texts: List[List[str]]) -> None:# encode text embeddings into the detectorself.texts = textsself.text_feats, _ = self.backbone.forward_text(texts)def _forward(self,batch_inputs: Tensor,batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:"""Network forward process. 
Usually includes backbone, neck and headforward without any post-processing."""img_feats, txt_feats, txt_masks = self.extract_feat(batch_inputs, batch_data_samples)results = self.bbox_head.forward(img_feats, txt_feats, txt_masks)return resultsdef extract_feat(self, batch_inputs: Tensor,batch_data_samples: SampleList) -> Tuple[Tuple[Tensor], Tensor]:"""Extract features."""txt_feats = Noneif batch_data_samples is None:texts = self.textstxt_feats = self.text_featselif isinstance(batch_data_samples,dict) and 'texts' in batch_data_samples:texts = batch_data_samples['texts']elif isinstance(batch_data_samples, list) and hasattr(batch_data_samples[0], 'texts'):texts = [data_sample.texts for data_sample in batch_data_samples]elif hasattr(self, 'text_feats'):texts = self.textstxt_feats = self.text_featselse:raise TypeError('batch_data_samples should be dict or list.')if txt_feats is not None:# forward image onlyimg_feats = self.backbone.forward_image(batch_inputs)else:img_feats, (txt_feats,txt_masks) = self.backbone(batch_inputs, texts)if self.with_neck:if self.mm_neck:img_feats = self.neck(img_feats, txt_feats)else:img_feats = self.neck(img_feats)return img_feats, txt_feats, txt_masks@MODELS.register_module()
class SimpleYOLOWorldDetector(YOLODetector):
    """Implementation of YOLO World Series with fixed prompt embeddings.

    Instead of encoding text at run time, this variant holds a learnable
    parameter ``self.embeddings``. It is either loaded with ``np.load``
    from ``embedding_path`` or randomly initialised with shape
    ``(num_prompts, prompt_dim)``. It is optionally refined by a residual
    MLP adapter. With ``reparameterized=True`` no embeddings or adapter are
    created, and the head is called without text features.

    Args:
        mm_neck: If True, the neck is called with ``(img_feats, txt_feats)``.
        num_train_classes: Class count assigned to the head in ``loss``.
        num_test_classes: Class count assigned to the head in ``predict``.
        prompt_dim: Embedding size for random init and the adapter.
        num_prompts: Number of randomly initialised prompts.
        embedding_path: File loaded with ``np.load``; empty (or None)
            means random initialisation.
        reparameterized: If True, skip creating embeddings and adapter.
        freeze_prompt: If True, the embeddings do not require grad.
        use_mlp_adapter: If True, add a residual 2-layer MLP on the prompts.
    """

    def __init__(self,
                 *args,
                 mm_neck: bool = False,
                 num_train_classes=80,
                 num_test_classes=80,
                 prompt_dim=512,
                 num_prompts=80,
                 embedding_path='',
                 reparameterized=False,
                 freeze_prompt=False,
                 use_mlp_adapter=False,
                 **kwargs) -> None:
        self.mm_neck = mm_neck
        # NOTE: named ``num_training_classes`` here but ``num_train_classes``
        # in YOLOWorldDetector; kept as-is for compatibility with callers.
        self.num_training_classes = num_train_classes
        self.num_test_classes = num_test_classes
        self.prompt_dim = prompt_dim
        self.num_prompts = num_prompts
        self.reparameterized = reparameterized
        self.freeze_prompt = freeze_prompt
        self.use_mlp_adapter = use_mlp_adapter
        super().__init__(*args, **kwargs)

        # Always define ``adapter`` so the attribute exists even when the
        # detector is reparameterized.
        self.adapter = None
        if not self.reparameterized:
            # Truthiness check also accepts ``embedding_path=None``, which
            # ``len(...)`` would have rejected with a TypeError.
            if embedding_path:
                import numpy as np
                self.embeddings = nn.Parameter(
                    torch.from_numpy(np.load(embedding_path)).float())
            else:
                # random init, each row L2-normalised
                embeddings = nn.functional.normalize(
                    torch.randn((num_prompts, prompt_dim)), dim=-1)
                self.embeddings = nn.Parameter(embeddings)

            self.embeddings.requires_grad = not self.freeze_prompt

            if use_mlp_adapter:
                self.adapter = nn.Sequential(
                    nn.Linear(prompt_dim, prompt_dim * 2), nn.ReLU(True),
                    nn.Linear(prompt_dim * 2, prompt_dim))

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples."""
        self.bbox_head.num_classes = self.num_training_classes
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        # NOTE(review): YOLOWorldDetector passes ``txt_masks`` to its head;
        # if the same head class is configured here, this call may also
        # need a mask argument — confirm against the head's signature.
        if self.reparameterized:
            losses = self.bbox_head.loss(img_feats, batch_data_samples)
        else:
            losses = self.bbox_head.loss(img_feats, txt_feats,
                                         batch_data_samples)
        return losses

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with
        post-processing."""
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        self.bbox_head.num_classes = self.num_test_classes
        if self.reparameterized:
            results_list = self.bbox_head.predict(img_feats,
                                                  batch_data_samples,
                                                  rescale=rescale)
        else:
            results_list = self.bbox_head.predict(img_feats,
                                                  txt_feats,
                                                  batch_data_samples,
                                                  rescale=rescale)
        batch_data_samples = self.add_pred_to_datasample(
            batch_data_samples, results_list)
        return batch_data_samples

    def _forward(
            self,
            batch_inputs: Tensor,
            batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing."""
        img_feats, txt_feats = self.extract_feat(batch_inputs,
                                                 batch_data_samples)
        if self.reparameterized:
            results = self.bbox_head.forward(img_feats)
        else:
            results = self.bbox_head.forward(img_feats, txt_feats)
        return results

    def extract_feat(
        self, batch_inputs: Tensor, batch_data_samples: SampleList
    ) -> Tuple[Tuple[Tensor, ...], Union[Tensor, None]]:
        """Extract image features and, unless reparameterized, prompt features.

        ``batch_data_samples`` is accepted for interface compatibility but is
        not read here: the text side always comes from ``self.embeddings``.

        Returns:
            tuple: ``(img_feats, txt_feats)``; ``txt_feats`` is None when
            ``reparameterized`` is True.
        """
        # only image features; no text is passed to the backbone
        img_feats, _ = self.backbone(batch_inputs, None)
        if not self.reparameterized:
            # use embeddings: [None] adds a leading dim of size 1
            txt_feats = self.embeddings[None]
            if self.adapter is not None:
                # residual adapter, then re-normalise along the last dim
                txt_feats = self.adapter(txt_feats) + txt_feats
                txt_feats = nn.functional.normalize(txt_feats, dim=-1, p=2)
            # tile along dim 0 to img_feats[0].shape[0] (presumably batch size)
            txt_feats = txt_feats.repeat(img_feats[0].shape[0], 1, 1)
        else:
            txt_feats = None
        if self.with_neck:
            if self.mm_neck:
                img_feats = self.neck(img_feats, txt_feats)
            else:
                img_feats = self.neck(img_feats)
        return img_feats, txt_feats