Semantic-SAM: Segment and Recognize Anything at Any Granularity
Contents
Abstract
Semantic-SAM
Research Background
Model Framework
Base Architecture
Multi-Granularity Prompt Mechanism
Decoupled Classification
Unified Prompt Representation
Many-to-Many Training Strategy
Experiments
Code
Summary
Abstract
Semantic-SAM is a novel image segmentation model, built on the Mask DINO framework and designed to address the shortcomings of the earlier SAM model in semantic awareness and multi-granularity segmentation. By jointly training on multiple datasets, it achieves high-quality segmentation of objects in images and can assign accurate semantic labels to the segmentation results. Additionally, Semantic-SAM introduces a multi-choice learning scheme that supports rich granularity control, enabling it to excel in generic, fine-grained, and interactive segmentation tasks. The model not only improves segmentation accuracy but also broadens the application scenarios of image segmentation, bringing a new breakthrough to the field of computer vision.
Semantic-SAM
Project repository: https://github.com/UX-Decoder/Semantic-SAM
Research Background
Problem definition:
- Existing segmentation models, such as Meta's SAM, lack semantic labels and multi-granularity segmentation ability;
- Dataset limitations: generic segmentation datasets such as COCO only annotate object-level semantics, part-level datasets are small in scale, and SA-1B contains multi-granularity masks but no semantic labels.
Goal: build a unified model that offers
- Semantic awareness: object-level and part-level semantic labels for every segmentation result;
- Rich granularity: segmentation at any granularity, from whole objects down to parts;
- Multi-task support: panoptic, instance, and part segmentation as well as interactive segmentation.
Model Framework
Base Architecture
Semantic-SAM uses Mask DINO as its base architecture; Mask DINO extends the DETR framework and supports end-to-end detection and segmentation. The architecture consists of:
- Vision encoder: a pre-trained Swin Transformer that extracts multi-scale image features;
- Decoder: the core of the modifications, built on deformable attention and serving both generic and interactive segmentation (see the sketch below).
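To make the data flow concrete, here is a minimal, runnable sketch of this encoder-decoder pipeline. All names (TinyEncoder, TinyDecoder) and shapes are illustrative assumptions rather than the repository's classes: a plain convolutional stem stands in for the Swin Transformer, and standard multi-head cross-attention stands in for deformable attention.
# Minimal sketch of the pipeline (assumption: names and shapes are illustrative only)
import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    """Stand-in for the pre-trained Swin Transformer encoder."""
    def __init__(self, dim=256):
        super().__init__()
        self.stem = nn.Conv2d(3, dim, kernel_size=16, stride=16)  # 1/16-resolution feature map

    def forward(self, images):                   # (B, 3, H, W)
        feats = self.stem(images)                # (B, C, H/16, W/16)
        return feats.flatten(2).transpose(1, 2)  # (B, N, C) token sequence

class TinyDecoder(nn.Module):
    """Stand-in for the deformable-attention decoder: prompt queries attend to image tokens."""
    def __init__(self, dim=256, num_classes=80):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)
        self.mask_head = nn.Linear(dim, dim)         # mask embedding per query
        self.cls_head = nn.Linear(dim, num_classes)  # semantic logits per query

    def forward(self, queries, tokens):
        q, _ = self.cross_attn(queries, tokens, tokens)                        # (B, Q, C)
        mask_logits = torch.einsum("bqc,bnc->bqn", self.mask_head(q), tokens)  # per-query mask over tokens
        return mask_logits, self.cls_head(q)

B, Q, C = 2, 6, 256
encoder, decoder = TinyEncoder(C), TinyDecoder(C)
tokens = encoder(torch.randn(B, 3, 256, 256))   # image features (multi-scale collapsed to one scale here)
queries = torch.randn(B, Q, C)                  # 6 prompt queries per click (see the next section)
masks, labels = decoder(queries, tokens)
print(masks.shape, labels.shape)                # torch.Size([2, 6, 256]) torch.Size([2, 6, 80])
The real model swaps these stand-ins for the pre-trained Swin backbone and the Mask DINO deformable decoder, but the pattern of prompt queries attending to image tokens is the same.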
Multi-Granularity Prompt Mechanism
To produce segmentation results at multiple levels, the model introduces K level-aware prompt vectors that expand a single point input into multi-granularity queries.
Prompt generation:
- A user click (x, y) is converted into an anchor box (x, y, w, h);
- Each click spawns 6 content prompt vectors, and each vector is bound to a learnable level embedding that represents a different granularity level.
Decoding output:
- The prompt vectors are fed into the decoder, which outputs 6 mask-semantics pairs (m_i, c_i):
m_i: the segmentation mask (a binary map); c_i: the semantic label (an object or part class).
# Pseudocode: prompt construction and decoding
prompts = []
for i in range(6):
    query = point_embedding(x, y) + level_embedding[i] + type_embedding("point")
    prompts.append(query)
masks, labels = deformable_decoder(prompts, image_features)
Decoupled Classification
To recognize semantics at both the object and the part level, the model decouples classification into two heads (a minimal sketch follows this list):
- Object classifier: recognizes whole-object categories (e.g., "dog", "car");
- Part classifier: recognizes local structures (e.g., "wheel", "head");
- Shared text encoder (e.g., CLIP): encodes semantic concepts in one common space, while the two classifiers are trained independently, which enables cross-category generalization (e.g., transferring from "dog head" to "lion head").
Training flexibility:
- Object-only annotations (e.g., COCO): only the object classification loss is enabled;
- Part annotations (e.g., Pascal Part): both the object and the part losses are enabled.
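As a rough illustration of the decoupling, the sketch below projects each decoder query twice and scores it against two separate banks of text embeddings (for instance, produced by a frozen CLIP text encoder). The module names, class counts, and dot-product scoring are assumptions for illustration, not the repository's exact implementation.
# Hedged sketch: decoupled object/part heads over a shared text-embedding space
import torch
import torch.nn as nn
import torch.nn.functional as F

class DecoupledClassifier(nn.Module):
    def __init__(self, dim=256):
        super().__init__()
        # Two independent projections of the same decoder query:
        # one for object semantics, one for part semantics.
        self.obj_proj = nn.Linear(dim, dim)
        self.part_proj = nn.Linear(dim, dim)

    def forward(self, queries, text_embeds_obj, text_embeds_part):
        # queries: (B, Q, C); text_embeds_*: (num_classes, C), e.g. from a frozen text encoder
        obj_logits = F.normalize(self.obj_proj(queries), dim=-1) @ F.normalize(text_embeds_obj, dim=-1).T
        part_logits = F.normalize(self.part_proj(queries), dim=-1) @ F.normalize(text_embeds_part, dim=-1).T
        return obj_logits, part_logits

clf = DecoupledClassifier()
queries = torch.randn(2, 6, 256)
obj_logits, part_logits = clf(queries, torch.randn(80, 256), torch.randn(30, 256))
# Object-only batches (e.g. COCO) would apply only the object loss;
# part-annotated batches (e.g. Pascal Part) would apply both.
obj_loss = F.cross_entropy(obj_logits.flatten(0, 1), torch.randint(0, 80, (12,)))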
Unified Prompt Representation
To simplify the architecture, the model unifies point prompts and box prompts into a single anchor-box format (see the conversion sketch after this list):
- Point prompt: (x, y) → (x, y, w, h);
- Box prompt: (x, y, w, h), used directly.
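A small conversion sketch, assuming normalized coordinates; the fixed width and height assigned to a click below are arbitrary illustrative values, not the ones used by the authors.
# Hedged sketch: clicks and boxes unified into (cx, cy, w, h) anchor boxes
import torch

def point_to_anchor(points, wh=0.02):
    """points: (N, 2) normalized (x, y) clicks -> (N, 4) anchor boxes (cx, cy, w, h)."""
    size = torch.full((points.shape[0], 2), wh, dtype=points.dtype)
    return torch.cat([points, size], dim=1)

def box_to_anchor(boxes):
    """boxes: (N, 4) already in (cx, cy, w, h); passed through unchanged."""
    return boxes

clicks = torch.tensor([[0.50, 0.40], [0.25, 0.75]])
boxes = torch.tensor([[0.50, 0.50, 0.30, 0.40]])
anchors = torch.cat([point_to_anchor(clicks), box_to_anchor(boxes)])
print(anchors.shape)  # torch.Size([3, 4]): both prompt types now share one representation
With a single representation, one positional encoding and one decoder path can serve both prompt types, which is exactly the simplification this section describes.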
Many-to-Many Training Strategy
To exploit the multi-granularity annotations in SA-1B, the model adopts many-to-many matching supervision (a matching and loss sketch follows this list).
Matching mechanism:
- A single click corresponds to multiple ground-truth masks at different granularities;
- The 6 predicted masks are matched against all relevant ground truths with the Hungarian algorithm in a many-to-many fashion, so that every predicted mask receives supervision.
Loss function:
- Mask loss: Dice loss + Focal loss;
- Semantic loss: decoupled object/part classification losses (cross-entropy).
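The sketch below shows one plausible way to realize this supervision: build a Dice-based cost matrix, repeat the ground-truth masks so that every prediction gets a target, run standard Hungarian matching with scipy.optimize.linear_sum_assignment, and apply Dice + Focal losses to the matched pairs. The repetition trick and the exact cost terms are assumptions for illustration; the authors' matching may differ in detail.
# Hedged sketch: many-to-many matching via a Dice cost and Hungarian assignment
import torch
from scipy.optimize import linear_sum_assignment
from torchvision.ops import sigmoid_focal_loss

def dice_cost(pred_logits, gt_masks, eps=1.0):
    """pred_logits: (P, H*W) raw mask logits; gt_masks: (G, H*W) in {0, 1}. Returns a (P, G) cost."""
    pred = pred_logits.sigmoid()
    inter = pred @ gt_masks.T                              # (P, G)
    denom = pred.sum(-1, keepdim=True) + gt_masks.sum(-1)  # broadcasts to (P, G)
    return 1.0 - (2 * inter + eps) / (denom + eps)

P, G, HW = 6, 2, 64 * 64                                   # 6 predictions, 2 GT granularities for this click
pred = torch.randn(P, HW)
gt = (torch.rand(G, HW) > 0.5).float()

# Repeat the GT set until it covers all predictions, then match one-to-one.
reps = -(-P // G)                                          # ceil(P / G)
gt_rep = gt.repeat(reps, 1)[:P]
cost = dice_cost(pred, gt_rep)
row, col = linear_sum_assignment(cost.detach().numpy())
matched_gt = gt_rep[torch.as_tensor(col)]                  # one target per prediction

# Mask loss on the matched pairs: Dice + Focal (the classification losses would be added on top).
dice = dice_cost(pred, matched_gt).diagonal().mean()
focal = sigmoid_focal_loss(pred, matched_gt, reduction="mean")
mask_loss = dice + focal
print(float(mask_loss))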
Experiments
Panoptic segmentation comparison:
Part segmentation comparison:
Code
Training script:
# ------------------------------------------------------------------------
# Copyright (c) MicroSoft, Inc. and its affiliates.
# Modified from OpenSeed https://github.com/IDEA-Research/OpenSeed by Feng Li (fliay@connect.ust.hk).
# ------------------------------------------------------------------------
"""
Semantic-SAM training and inference script, based on MaskDINO and OpenSeed.
"""
try:
    from shapely.errors import ShapelyDeprecationWarning
    import warnings
    warnings.filterwarnings('ignore', category=ShapelyDeprecationWarning)
except:
    pass
import copy
import itertools
import logging
import os
import time
from typing import Any, Dict, List, Set

import torch

import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg, CfgNode
from detectron2.projects.deeplab import add_deeplab_config, build_lr_scheduler
from detectron2.utils.logger import setup_logger
from detectron2.config import LazyConfig, instantiate

# dataloader and evaluator
from datasets import (
    build_train_dataloader,
    build_evaluator,
    build_eval_dataloader,
)
import random
from detectron2.engine import (
    DefaultTrainer,
    default_argument_parser,
    default_setup,
    hooks,
    launch,
    create_ddp_model,
    AMPTrainer,
    SimpleTrainer,
)
import weakref

from semantic_sam import build_model
from semantic_sam.BaseModel import BaseModel

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class Trainer(DefaultTrainer):
    """
    Extension of the Trainer class adapted to MaskFormer.
    """

    def __init__(self, cfg):
        super(DefaultTrainer, self).__init__()
        logger = logging.getLogger("detectron2")
        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
            setup_logger()
        cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())
        # Assume these objects must be constructed in this order.
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)
        model = create_ddp_model(model, broadcast_buffers=False)
        self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
            model, data_loader, optimizer
        )
        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
        # add model EMA
        kwargs = {
            'trainer': weakref.proxy(self),
        }
        # kwargs.update(model_ema.may_get_ema_checkpointer(cfg, model)) TODO: release ema training for large models
        self.checkpointer = DetectionCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg['OUTPUT_DIR'],
            **kwargs,
        )
        self.start_iter = 0
        self.max_iter = cfg['SOLVER']['MAX_ITER']
        self.cfg = cfg
        self.register_hooks(self.build_hooks())
        # TODO: release model conversion checkpointer from DINO to MaskDINO
        self.checkpointer = DetectionCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg['OUTPUT_DIR'],
            **kwargs,
        )
        # TODO: release GPU cluster submit scripts based on submitit for multi-node training

    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.
        Returns:
            list[HookBase]:
        """
        cfg = copy.deepcopy(self.cfg)
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN
        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(),
            None,
        ]
        # Do PreciseBN before checkpointer, because it updates the model and need to
        # be saved by checkpointer.
        # This is not always the best: if checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.
        if comm.is_main_process():
            ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))

        def test_and_save_results():
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # Do evaluation after checkpointer, because then if it fails,
        # we can use the saved checkpoint to debug.
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        if comm.is_main_process():
            # Here the default print/log frequency of each writer is used.
            # run writers in the end, so that evaluation metrics are written
            ret.append(hooks.PeriodicWriter(self.build_writers(), period=1))
        return ret

    @classmethod
    def build_model(cls, cfg):
        """
        Returns:
            torch.nn.Module:
        It now calls :func:`detectron2.modeling.build_model`.
        Overwrite it if you'd like a different model.
        """
        model = BaseModel(cfg, build_model(cfg)).cuda()
        logger = logging.getLogger(__name__)
        logger.info("Model:\n{}".format(model))
        return model

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        return build_evaluator(cfg, dataset_name, output_folder=output_folder)

    @classmethod
    def build_train_loader(cls, cfg):
        return build_train_dataloader(cfg, )

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        # import ipdb; ipdb.set_trace()
        loader = build_eval_dataloader(cfg, )
        return loader

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        It now calls :func:`detectron2.solver.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer)

    @classmethod
    def build_optimizer(cls, cfg, model):
        cfg_solver = cfg['SOLVER']
        weight_decay_norm = cfg_solver['WEIGHT_DECAY_NORM']
        weight_decay_embed = cfg_solver['WEIGHT_DECAY_EMBED']
        weight_decay_bias = cfg_solver.get('WEIGHT_DECAY_BIAS', 0.0)

        defaults = {}
        defaults["lr"] = cfg_solver['BASE_LR']
        defaults["weight_decay"] = cfg_solver['WEIGHT_DECAY']

        norm_module_types = (
            torch.nn.BatchNorm1d,
            torch.nn.BatchNorm2d,
            torch.nn.BatchNorm3d,
            torch.nn.SyncBatchNorm,
            # NaiveSyncBatchNorm inherits from BatchNorm2d
            torch.nn.GroupNorm,
            torch.nn.InstanceNorm1d,
            torch.nn.InstanceNorm2d,
            torch.nn.InstanceNorm3d,
            torch.nn.LayerNorm,
            torch.nn.LocalResponseNorm,
        )

        lr_multiplier = cfg['SOLVER']['LR_MULTIPLIER']

        params: List[Dict[str, Any]] = []
        memo: Set[torch.nn.parameter.Parameter] = set()
        for module_name, module in model.named_modules():
            for module_param_name, value in module.named_parameters(recurse=False):
                if not value.requires_grad:
                    continue
                # Avoid duplicating parameters
                if value in memo:
                    continue
                memo.add(value)

                hyperparams = copy.copy(defaults)
                for key, lr_mul in lr_multiplier.items():
                    if key in "{}.{}".format(module_name, module_param_name):
                        hyperparams["lr"] = hyperparams["lr"] * lr_mul
                        if comm.is_main_process():
                            logger.info("Modify Learning rate of {}: {}".format(
                                "{}.{}".format(module_name, module_param_name), lr_mul))

                if (
                    "relative_position_bias_table" in module_param_name
                    or "absolute_pos_embed" in module_param_name
                ):
                    hyperparams["weight_decay"] = 0.0
                if isinstance(module, norm_module_types):
                    hyperparams["weight_decay"] = weight_decay_norm
                if isinstance(module, torch.nn.Embedding):
                    hyperparams["weight_decay"] = weight_decay_embed
                if "bias" in module_name:
                    hyperparams["weight_decay"] = weight_decay_bias
                params.append({"params": [value], **hyperparams})

        def maybe_add_full_model_gradient_clipping(optim):
            # detectron2 doesn't have full model gradient clipping now
            clip_norm_val = cfg_solver['CLIP_GRADIENTS']['CLIP_VALUE']
            enable = (
                cfg_solver['CLIP_GRADIENTS']['ENABLED']
                and cfg_solver['CLIP_GRADIENTS']['CLIP_TYPE'] == "full_model"
                and clip_norm_val > 0.0
            )

            class FullModelGradientClippingOptimizer(optim):
                def step(self, closure=None):
                    all_params = itertools.chain(*[x["params"] for x in self.param_groups])
                    torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
                    super().step(closure=closure)

            return FullModelGradientClippingOptimizer if enable else optim

        optimizer_type = cfg_solver['OPTIMIZER']
        if optimizer_type == "SGD":
            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
                params, cfg_solver['BASE_LR'], momentum=cfg_solver['MOMENTUM']
            )
        elif optimizer_type == "ADAMW":
            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
                params, cfg_solver['BASE_LR']
            )
        else:
            raise NotImplementedError(f"no optimizer type {optimizer_type}")
        return optimizer

    @staticmethod
    def auto_scale_workers(cfg, num_workers: int):
        """
        Returns:
            CfgNode: a new config. Same as original if ``cfg.SOLVER.REFERENCE_WORLD_SIZE==0``.
        """
        old_world_size = cfg.SOLVER.REFERENCE_WORLD_SIZE
        if old_world_size == 0 or old_world_size == num_workers:
            return cfg
        cfg = copy.deepcopy(cfg)
        # frozen = cfg.is_frozen()
        # cfg.defrost()

        assert (
            cfg.SOLVER.IMS_PER_BATCH % old_world_size == 0
        ), "Invalid REFERENCE_WORLD_SIZE in config!"
        scale = num_workers / old_world_size
        bs = cfg.SOLVER.IMS_PER_BATCH = int(round(cfg.SOLVER.IMS_PER_BATCH * scale))
        lr = cfg.SOLVER.BASE_LR = cfg.SOLVER.BASE_LR * scale
        max_iter = cfg.SOLVER.MAX_ITER = int(round(cfg.SOLVER.MAX_ITER / scale))
        warmup_iter = cfg.SOLVER.WARMUP_ITERS = int(round(cfg.SOLVER.WARMUP_ITERS / scale))
        cfg.SOLVER.STEPS = tuple(int(round(s / scale)) for s in cfg.SOLVER.STEPS)
        cfg.TEST.EVAL_PERIOD = int(round(cfg.TEST.EVAL_PERIOD / scale))
        cfg.SOLVER.CHECKPOINT_PERIOD = int(round(cfg.SOLVER.CHECKPOINT_PERIOD / scale))
        cfg.SOLVER.REFERENCE_WORLD_SIZE = num_workers  # maintain invariant
        logger = logging.getLogger(__name__)
        logger.info(
            f"Auto-scaling the config to batch_size={bs}, learning_rate={lr}, "
            f"max_iter={max_iter}, warmup={warmup_iter}."
        )
        return cfg

    @classmethod
    def test(cls, cfg, model, evaluators=None):
        from utils.misc import hook_metadata, hook_switcher, hook_opt
        from detectron2.utils.logger import log_every_n_seconds
        import datetime

        # build dataloade
        dataloaders = cls.build_test_loader(cfg, dataset_name=None)
        dataset_names = cfg['DATASETS']['TEST']
        model = model.eval().cuda()
        model_without_ddp = model
        if not type(model) == BaseModel:
            model_without_ddp = model.module
        for dataloader, dataset_name in zip(dataloaders, dataset_names):
            # build evaluator
            evaluator = build_evaluator(cfg, dataset_name, cfg['OUTPUT_DIR'])
            evaluator.reset()
            with torch.no_grad():
                # setup task
                if 'sam' in dataset_names:
                    task = 'multi_granularity'
                else:
                    task = 'interactive'
                hook_switcher(model_without_ddp, dataset_name)
                # setup timer
                total = len(dataloader)
                num_warmup = min(5, total - 1)
                start_time = time.perf_counter()
                total_data_time = 0
                total_compute_time = 0
                total_eval_time = 0
                start_data_time = time.perf_counter()

                for idx, batch in enumerate(dataloader):
                    total_data_time += time.perf_counter() - start_data_time
                    if idx == num_warmup:
                        start_time = time.perf_counter()
                        total_data_time = 0
                        total_compute_time = 0
                        total_eval_time = 0
                    start_compute_time = time.perf_counter()
                    # forward
                    with torch.autocast(device_type='cuda', dtype=torch.float16):
                        # import ipdb; ipdb.set_trace()
                        outputs = model(batch, inference_task=task)
                    total_compute_time += time.perf_counter() - start_compute_time
                    start_eval_time = time.perf_counter()
                    evaluator.process(batch, outputs)
                    total_eval_time += time.perf_counter() - start_eval_time

                    iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
                    data_seconds_per_iter = total_data_time / iters_after_start
                    compute_seconds_per_iter = total_compute_time / iters_after_start
                    eval_seconds_per_iter = total_eval_time / iters_after_start
                    total_seconds_per_iter = (time.perf_counter() - start_time) / iters_after_start
                    if comm.is_main_process() and (idx >= num_warmup * 2 or compute_seconds_per_iter > 5):
                        eta = datetime.timedelta(seconds=int(total_seconds_per_iter * (total - idx - 1)))
                        log_every_n_seconds(
                            logging.INFO,
                            (
                                f"Inference done {idx + 1}/{total}. "
                                f"Dataloading: {data_seconds_per_iter:.4f} s/iter. "
                                f"Inference: {compute_seconds_per_iter:.4f} s/iter. "
                                f"Eval: {eval_seconds_per_iter:.4f} s/iter. "
                                f"Total: {total_seconds_per_iter:.4f} s/iter. "
                                f"ETA={eta}"
                            ),
                            n=5,
                        )
                    start_data_time = time.perf_counter()
            # evaluate
            results = evaluator.evaluate()
        model = model.train().cuda()


def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()
    cfg = LazyConfig.load(args.config_file)
    cfg = LazyConfig.apply_overrides(cfg, args.opts)
    # cfg.freeze()
    default_setup(cfg, args)
    setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="maskdino")
    return cfg


def main(args=None):
    cfg = setup(args)
    print("Command cfg:", cfg)
    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        return res

    trainer = Trainer(cfg)
    if len(args.lang_weight) > 0:
        # load language weight for semantic
        import copy
        weight = copy.deepcopy(trainer.cfg.MODEL.WEIGHTS)
        trainer.cfg.MODEL.WEIGHTS = args.lang_weight
        print("load original language language weight!!!!!!")
        trainer.resume_or_load(resume=args.resume)
        trainer.cfg.MODEL.WEIGHTS = weight
    print("load pretrained model weight!!!!!!")
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()


if __name__ == "__main__":
    # main()
    parser = default_argument_parser()
    parser.add_argument('--eval_only', action='store_true')
    parser.add_argument('--EVAL_FLAG', type=int, default=1)
    parser.add_argument('--lang_weight', type=str, default='')
    args = parser.parse_args()
    port = random.randint(1000, 20000)
    args.dist_url = 'tcp://127.0.0.1:' + str(port)
    print("Command Line Args:", args)
    print("pwd:", os.getcwd())
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )
Demo script (demo.py):
# --------------------------------------------------------
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
# --------------------------------------------------------
import gradio as gr
import torch
import argparse

# from gradio import processing_utils
from semantic_sam.BaseModel import BaseModel
from semantic_sam import build_model
from utils.dist import init_distributed_mode
from utils.arguments import load_opt_from_config_file
from utils.constants import COCO_PANOPTIC_CLASSES

from tasks import interactive_infer_image_idino_m2m


def parse_option():
    parser = argparse.ArgumentParser('SemanticSAM Demo', add_help=False)
    parser.add_argument('--conf_files', default="configs/semantic_sam_only_sa-1b_swinL.yaml", metavar="FILE", help='path to config file', )
    parser.add_argument('--ckpt', default="", metavar="FILE", help='path to ckpt', )
    args = parser.parse_args()
    return args


'''
build args
'''
args = parse_option()
cur_model = 'None'

'''
build model
'''
model = None
model_size = None
ckpt = None
cfgs = {
    'T': "configs/semantic_sam_only_sa-1b_swinT.yaml",
    'L': "configs/semantic_sam_only_sa-1b_swinL.yaml",
}

# audio = whisper.load_model("base")
sam_cfg = cfgs['L']
opt = load_opt_from_config_file(sam_cfg)
model_sam = BaseModel(opt, build_model(opt)).from_pretrained(args.ckpt).eval().cuda()


@torch.no_grad()
def inference(image, text, text_part, text_thresh, *args, **kwargs):
    text_size, hole_scale, island_scale = 640, 100, 100
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        semantic = False
        model = model_sam
        a, b = interactive_infer_image_idino_m2m(model, image, text, text_part, text_thresh,
                                                 text_size, hole_scale, island_scale, semantic,
                                                 *args, **kwargs)
        return a, b


class ImageMask(gr.components.Image):
    """
    Sets: source="canvas", tool="sketch"
    """

    is_template = True

    def __init__(self, **kwargs):
        super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)

    def preprocess(self, x):
        return super().preprocess(x)


'''
launch app
'''
title = "SEMANTIC-SAM: SEGMENT AND RECOGNIZE ANYTHING AT ANY GRANULARITY"
article = "The Demo is Run on SEMANTIC SAM."

from detectron2.data import MetadataCatalog
from utils.constants import COCO_PANOPTIC_CLASSES
from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES

metadata = MetadataCatalog.get('coco_2017_train_panoptic')
all_classes = [name.replace('-other', '').replace('-merged', '') for name in COCO_PANOPTIC_CLASSES]
all_parts = ['arm', 'beak', 'body', 'cap', 'door', 'ear', 'eye', 'foot', 'hair', 'hand', 'handlebar',
             'head', 'headlight', 'horn', 'leg', 'license plate', 'mirror', 'mouth', 'muzzle', 'neck',
             'nose', 'paw', 'plant', 'pot', 'saddle', 'tail', 'torso', 'wheel', 'window', 'wing']

demo = gr.Blocks()
image = ImageMask(label="Click on Image (Please only click one point, or our model will take the center of all points as the clicked location. Remember to clear the click after each interaction, or we will take the center of the current click and previous ones as the clicked location.)", type="pil", brush_radius=15.0).style(height=512)
gallery_output = gr.Gallery(label="Image Gallery sorted by IoU score.", min_width=1536).style(grid=6)
gallery_output2 = gr.Gallery(label="Image Gallery sorted by mask area.", min_width=1536).style(grid=6)
text = gr.components.Textbox(label="Categories. (The default is the categories in COCO panoptic segmentation.)", value=":".join(all_classes), visible=False)
text_part = gr.components.Textbox(label="Part Categories. (The default is the categories in PASCAL Part.)", value=":".join(all_parts), visible=False)
text_res = gr.components.Textbox(label="\"class:part(score)\" of all predictions (seperated by ;): ", visible=True)
text_thresh = gr.components.Textbox(label="The threshold to filter masks with low iou score.", value="0.5", visible=True)
text_size = gr.components.Textbox(label="image size (shortest edge)", value="640", visible=True)
hole_scale = gr.components.Textbox(label="holes scale", value="100", visible=True)
island_scale = gr.components.Textbox(label="island scale", value="100", visible=True)
text_model_size = gr.components.Textbox(label="model size (L or T)", value="L", visible=True)
text_ckpt = gr.components.Textbox(label="ckpt path (relative to /mnt/output/)", value="fengli/joint_part_idino/train_interactive_all_m2m_swinL_bs16_0.1part9_nohash_bs1_resume_all_local_0.15_onlysa_swinL_4node_mnode/model_0099999.pth", visible=True)
text_ckpt_now = gr.components.Textbox(label="current ckpt path (relative to /mnt/output/)", value="", visible=True)
semantic = gr.Checkbox(label="Semantics", info="Do you use semantic? (The semantic model in the demo is trained on SA-1B, COCO and PASCAL Part.)")

title = '''
# Semantic-SAM: Segment and Recognize Anything at Any Granularity

# [[Read our arXiv Paper](https://arxiv.org/pdf/2307.04767.pdf)\] \[[Github page](https://github.com/UX-Decoder/Semantic-SAM)\]

# Please only click one point, or our model will take the center of all points as the clicked location. Remember to clear the click after each interaction, or we will take the center of the current click and previous ones as the clicked location.
'''


def change_vocab(choice):
    if choice:
        return gr.update(visible=True)
    else:
        return gr.update(visible=False)


with demo:
    with gr.Row():
        with gr.Column(scale=9.0):
            generation_tittle = gr.Markdown(title)
            # generation_tittle.render()
            with gr.Row(scale=20.0):
                image.render()
            example = gr.Examples(
                examples=[
                    ["examples/tank.png"],
                    ["examples/castle.png"],
                    ["examples/fries1.png"],
                    ["examples/4.png"],
                    ["examples/5.png"],
                    ["examples/corgi2.jpg"],
                    ["examples/minecraft2.png"],
                    ["examples/ref_cat.jpeg"],
                    ["examples/img.png"],
                ],
                inputs=image,
                cache_examples=False,
            )
            with gr.Row(scale=1.0):
                with gr.Column():
                    text_thresh.render()
            with gr.Row(scale=2.0):
                clearBtn = gr.ClearButton(components=[image])
                runBtn = gr.Button("Run")
            with gr.Row(scale=6.0):
                text.render()
            with gr.Row(scale=1.0):
                text_part.render()
            gallery_tittle = gr.Markdown("# The masks sorted by IoU scores (masks with low score may have low quality).")
            with gr.Row(scale=9.0):
                gallery_output.render()
            gallery_tittle1 = gr.Markdown("# The masks sorted by mask areas.")
            with gr.Row(scale=9.0):
                gallery_output2.render()
    title = title,
    article = article,
    allow_flagging = 'never',

    runBtn.click(inference,
                 inputs=[image, text, text_part, text_thresh],
                 outputs=[gallery_output, gallery_output2])

demo.queue().launch(share=True, server_port=6082)
Model output:
Ground truth:
Summary
Semantic-SAM is a breakthrough image segmentation model built on the Mask DINO framework, specifically optimized to address the weaknesses of earlier models in semantic awareness and multi-granularity segmentation. By integrating multiple datasets and adopting a multi-choice learning scheme, the model delivers high-quality segmentation at different granularity levels and provides accurate semantic labels for its results. Semantic-SAM performs strongly in generic, fine-grained, and interactive segmentation, noticeably improving segmentation accuracy while broadening the range of applications, and it gives the computer vision community a powerful tool for segmenting and recognizing anything at any granularity.