当前位置: 首页 > ai >正文

基于自定义数据集微调SigLIP2-分类任务

  本项目基于Google的SigLIP2模型,构建了一个智能xx等级分类系统。通过联合训练策略(对比学习+分类学习),实现了对xx图像的精确等级分类(Grade 2-5),提供AI辅助支持。

一、任务背景

xx等级分类的重要性

  xx等级的准确判断对后续方案制定和预后评估至关重要:
- Grade 1:正常情况,目前不纳入分类内容。
- Grade 2:等级2
- Grade 3:等级3
- Grade 4:等级4
- Grade 5:等级5

技术挑战

  图像特征复杂,等级边界模糊,传统方法依赖专家经验,主观性强,需要同时理解图像内容和xx描述。

二、数据集构建

  • 标注格式:JSONL格式,包含图像路径和文本描述
  • 数据分布:多等级xx图像,每张图像配有详细的描述

数据标注

CVAT对图片进行tag标注,自定义标注工具进行jsonl标注文件生成。

{"image_path": "images/36_2650.jpg", "text": "Grade 3: Presence of ANY of the following: description."}
{"image_path": "images/36_2675.jpg", "text": "Grade 3: Presence of ANY of the following: description."}
dataset/
├── images/
│   ├── img001.jpg
│   ├── img002.jpg
├── labels.jsonl

数据预处理

python
# 数据加载和标签提取
def load_data(jsonl_path, image_dir):with open(jsonl_path, "r") as f:entries = [json.loads(line.strip()) for line in f]data = []for entry in entries:image_path = os.path.join(image_dir, os.path.basename(entry["image_path"]))text = entry["text"]label = extract_label(text)  # 从文本中提取等级标签if label != -1:data.append((image_path, text, label))return data# 标签映射id2grade = {0: "Grade 2", 1: "Grade 3", 2: "Grade 4", 3: "Grade 5"}

三、模型架构

核心模型

  • 基础模型:SigLIP2-Base-Patch16-384
  • 输入尺寸:384×384像素
  • 预训练权重:Google官方预训练模型,自行到hugging face下载
    在这里插入图片描述

联合训练架构

python
class SigLIP2WithClassifier(nn.Module):def __init__(self, base_model, processor, num_classes=4):self.siglip = base_model          # SigLIP2主干网络self.classifier = nn.Linear(embed_dim, num_classes)  # 分类头self.temperature = 0.07           # 对比学习温度参数

损失函数设计

  1. 对比损失(Contrastive Loss)

    • 目标:学习图像-文本对应关系
    • 公式:logits_per_image = (image_embeds @ text_embeds.T) / temperature
  2. 分类损失(Classification Loss)

    • 目标:精确预测烧伤等级
    • 公式:CrossEntropy(classifier(image_embeds), class_labels)
  3. 联合损失

    • 总损失 = 对比损失 + 分类损失

训练配置

python
# 训练机器:H100服务器
# 训练参数
epochs = 20
learning_rate = 2e-5
batch_size = 16
device = "cuda"  # GPU训练

模型保存

torch.save(model.state_dict(), os.path.join(save_dir, "parkland_siglip2.pt"))

在这里插入图片描述

三、推理部署

模型加载

python
# 加载训练好的模型
model = SigLIP2WithClassifier(base_model, processor, num_classes=4)
model.load_state_dict(torch.load("parkland_siglip2.pt"))
model.eval()

推理流程

  1. 图像预处理:调整尺寸、标准化
  2. 特征提取:通过SigLIP2获取图像嵌入
  3. 分类预测:通过分类头预测烧伤等级
  4. 结果输出:返回等级概率分布
 # 加载和预处理图像image = Image.open(image_path).convert("RGB")inputs = processor(images=image, return_tensors="pt").to(device)# 推理image_features = model.siglip.vision_model(pixel_values=inputs["pixel_values"]).pooler_outputlogits = model.classifier(image_features)probs = torch.softmax(logits, dim=-1)pred = probs.argmax(dim=-1).item()

在这里插入图片描述

http://www.xdnf.cn/news/15779.html

相关文章:

  • PDF 编辑器:多文件合并 拆分 旋转 顺序随便调 加水印 密码锁 页码背景
  • [学习] 深入理解傅里叶变换:从时域到频域的桥梁
  • vscode环境下c++的常用快捷键和插件
  • 嵌入式通信DQ单总线协议及UART(一)
  • Linux练习二
  • 鸿蒙蓝牙通信
  • [AI风堇]基于ChatGPT3.5+科大讯飞录音转文字API+GPT-SOVITS的模拟情感实时语音对话项目
  • 字节跳动开源Seed-X 7B多语言翻译模型:28语种全覆盖,性能超越GPT-4、Gemini-2.5与Claude-3.5
  • 关于Vuex
  • GeoPandas 城市规划:Python 空间数据初学者指南
  • 零基础 “入坑” Java--- 十二、抽象类和接口
  • ndexedDB 与 LocalStorage:全面对比分析
  • aosp15实现SurfaceFlinger的dump输出带上Layer详细信息踩坑笔记
  • EP01:【Python 第一弹】基础入门知识
  • Vue rem回顾
  • 文档表格标题跑到表格下方,或标题跟表格空隔太大如何处理
  • Java无服务架构新范式:Spring Native与AWS Lambda冷启动深度优化
  • Flutter基础(前端教程①⑤-API请求转化为模型列成列表展示实战)
  • 财务数字化——解读财务指标及财务分析的基本步骤与方法【附全文阅读】
  • Error:HTTP Status 405 - HTTP method POST is not supported by this URL
  • 大数据之路:阿里巴巴大数据实践——日志采集与数据同步
  • 短视频矩阵的未来前景:机遇无限,挑战并存
  • [spring6: Advice Advisor Advised]-快速理解
  • stm32继电器使用方法
  • 【HarmonyOS】Ability Kit - Stage模型
  • 2023 年 5 月青少年软编等考 C 语言八级真题解析
  • 安装tomcat启动startup.bat出现闪退问题
  • 驾驭 Spring Boot 事件机制:8 个内置事件 + 自定义扩展实战
  • windows wsl ubuntu 如何安装 maven
  • 前端知识回顾-登录界面