当前位置: 首页 > news >正文

Bilateral Reference for High-Resolution Dichotomous Image Segmentation

代码来源

https://github.com/ZhengPeng7/BiRefNet

模块作用

DIS 是一种旨在对高分辨率图像中的目标物体进行精确分割的技术,尤其适用于具有复杂细微结构的物体,例如细长的边缘或微小细节。传统方法在处理这类任务时往往难以捕捉细微特征或恢复高分辨率细节,因此论文提出了一种新颖的网络架构BiRefNet以解决这些挑战。

模块结构

定位模块(LM)
  • 输入高分辨率图像至视觉变换器骨干网络。
  • 提取多尺度的层次特征,捕捉全局语义信息。
  • 通过特征融合和压缩,生成低分辨率的粗略预测图。
  • 原理:利用变换器的全局建模能力,在低分辨率下快速定位目标物体,避免直接处理高分辨率带来的计算负担。
重建模块(RM)
  • 接收定位模块输出的低分辨率粗略预测图。
  • 在解码器的多个阶段,逐步上采样并结合双边参考信息。
  • 输出高分辨率的精细分割图。
  • 原理:通过将原始图像的分块输入解码器,提供高分辨率的细节参考,确保重建过程中细节不丢失。通过梯度图的监督,引导模型聚焦于边缘和细微结构,避免模糊或遗漏关键区域。从低分辨率到高分辨率的分阶段上采样,确保全局一致性和局部精确性的平衡。

代码

class BiRefNet(nn.Module,PyTorchModelHubMixin,library_name="birefnet",repo_url="https://github.com/ZhengPeng7/BiRefNet",tags=['Image Segmentation', 'Background Removal', 'Mask Generation', 'Dichotomous Image Segmentation', 'Camouflaged Object Detection', 'Salient Object Detection']
):def __init__(self, bb_pretrained=True):super(BiRefNet, self).__init__()self.config = Config()self.epoch = 1self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained)channels = self.config.lateral_channels_in_collectionif self.config.auxiliary_classification:self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.cls_head = nn.Sequential(nn.Linear(channels[0], len(class_labels_TR_sorted)))if self.config.squeeze_block:self.squeeze_module = nn.Sequential(*[eval(self.config.squeeze_block.split('_x')[0])(channels[0]+sum(self.config.cxt), channels[0])for _ in range(eval(self.config.squeeze_block.split('_x')[1]))])self.decoder = Decoder(channels)if self.config.ender:self.dec_end = nn.Sequential(nn.Conv2d(1, 16, 3, 1, 1),nn.Conv2d(16, 1, 3, 1, 1),nn.ReLU(inplace=True),)# refine patch-level segmentationif self.config.refine:if self.config.refine == 'itself':self.stem_layer = StemLayer(in_channels=3+1, inter_channels=48, out_channels=3, norm_layer='BN' if self.config.batch_size > 1 else 'LN')else:self.refiner = eval('{}({})'.format(self.config.refine, 'in_channels=3+1'))if self.config.freeze_bb:# Freeze the backbone...print(self.named_parameters())for key, value in self.named_parameters():if 'bb.' in key and 'refiner.' not in key:value.requires_grad = Falsedef forward_enc(self, x):if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']:x1 = self.bb.conv1(x); x2 = self.bb.conv2(x1); x3 = self.bb.conv3(x2); x4 = self.bb.conv4(x3)else:x1, x2, x3, x4 = self.bb(x)if self.config.mul_scl_ipt:B, C, H, W = x.shapex_pyramid = F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True)if self.config.mul_scl_ipt == 'cat':if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']:x1_ = self.bb.conv1(x_pyramid); x2_ = self.bb.conv2(x1_); x3_ = self.bb.conv3(x2_); x4_ = self.bb.conv4(x3_)else:x1_, x2_, x3_, x4_ = self.bb(x_pyramid)x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], dim=1)x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], dim=1)x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], dim=1)x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], dim=1)elif self.config.mul_scl_ipt == 'add':x1_, x2_, x3_, x4_ = self.bb(x_pyramid)x1 = x1 + F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)x2 = x2 + F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)x3 = x3 + F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)x4 = x4 + F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)class_preds = self.cls_head(self.avgpool(x4).view(x4.shape[0], -1)) if self.training and self.config.auxiliary_classification else Noneif self.config.cxt:x4 = torch.cat((*[F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True),F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True),F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True),][-len(self.config.cxt):],x4),dim=1)return (x1, x2, x3, x4), class_predsdef forward_ori(self, x):########## Encoder ##########(x1, x2, x3, x4), class_preds = self.forward_enc(x)if self.config.squeeze_block:x4 = self.squeeze_module(x4)########## Decoder ##########features = [x, x1, x2, x3, x4]if self.training and self.config.out_ref:features.append(laplacian(torch.mean(x, dim=1).unsqueeze(1), kernel_size=5))scaled_preds = self.decoder(features)return scaled_preds, class_predsdef forward(self, x):scaled_preds, class_preds = self.forward_ori(x)class_preds_lst = [class_preds]return [scaled_preds, class_preds_lst] if self.training else scaled_preds

总结

本文提出了一个配备双边参考的 BiRefNet 框架,该框架可在同一框架内执行二分图像分割、高分辨率显著目标检测和隐藏目标检测。通过全面的实验,研究者发现未缩放的源图像和对信息丰富区域的关注对于生成 HR 图像中精细且细节丰富的区域至关重要。为此,研究者提出了双边参考来填充精细部分中缺失的信息(内向参考),并引导模型更加关注细节更丰富的区域(外向参考)。这显著提升了模型捕捉微小像素特征的能力。为了降低 HR 数据训练的高昂训练成本,本文还提供了各种实用技巧,以实现更高质量的预测和更快的收敛速度。在 13 个基准测试中取得的优异结果证明了BiRefNet 的卓越性能和强大的泛化能力。

http://www.xdnf.cn/news/1248967.html

相关文章:

  • 智慧社区(八)——社区人脸识别出入管理系统设计与实现
  • 轻量应用服务器Centos系统上安装jdk8和Jdk17教程(详细)
  • (ZipList入门笔记二)为何ZipList可以实现内存压缩,可以详细介绍一下吗
  • 在AI时代,如何制定有效的职业规划?AI时代职业规划+AI产品经理角色
  • 探索设计模式的宝库:Java-Design-Patterns
  • FastDeploy2.0:报qwen2.embed_tokens.weight
  • 3. 为什么 0.1 + 0.2 != 0.3
  • 多传感器融合
  • Redis之Set和SortedSet类型常用命令
  • leetcode-python-删除链表的倒数第 N 个结点
  • VUE+SPRINGBOOT从0-1打造前后端-前后台系统-邮箱重置密码
  • 使用ProxySql实现MySQL的读写分离
  • ubuntu24安装vulkan-sdk
  • 一文搞定JavaServerPages基础,从0开始写一个登录与人数统计页面
  • Rust进阶-part4-智能指针2
  • 力扣106:从中序与后序遍历序列构造二叉树
  • VUE+SPRINGBOOT从0-1打造前后端-前后台系统-登录实现
  • Redis里面什么是sdshdr,可以详细介绍一下吗?
  • Linux lvm逻辑卷管理
  • 跑yolov5的train.py时,ImportError: Failed to initialize: Bad git executable.
  • 【Linux】特效爆满的Vim的配置方法 and make/Makefile原理
  • 一种红外遥控RGB灯带控制器-最低价MCU
  • MySQL间隙锁在查询时锁定的范围
  • 前端遇到页面卡顿问题,如何排查和解决?
  • 【运维部署篇】OpenShift:企业级容器应用平台全面解析
  • Spring 的优势
  • Springboot集成Log4j2+MDC串联单次请求的日志
  • HBM Basic(VCU128)
  • 《Python基础》第3期:使用PyCharm编写Hello World
  • Leetcode-2080区间内查询数字的频率