当前位置: 首页 > java >正文

【TrOCR】用Transformer和torch库实现TrOCR模型

项目结构:

TrOCR/
├── config.py               # 所有配置参数(路径、超参数等)
├── dataset.py              # 数据集类 + 数据增强(合并 data_augmentation)
├── model.py                # 模型加载与配置
├── utils.py                # 通用工具(含日志功能,合并 logger.py)
├── train.py                # 训练逻辑 + 入口(合并 trainer.py)
├── predict.py              # 推理接口 + 入口(合并 inference.py)
├── evaluate.py             # 评估指标 + 入口(合并 metrics.py)
├── requirements.txt        # 依赖库
├── README.md               # 项目说明
├── data/                   # 数据集
│   ├── train/(images + labels.json)
│   ├── val/(images + labels.json)
│   └── test/(images + labels.json)
├── models/                 # 保存训练好的模型
└── logs/                   # 训练日志、评估报告

数据集的数据结构

我的数据集路径:C:\Users\Virgil\Desktop\dataetOCR\ChineseOcr2k,目录下有train和val两个文件夹,分别是images和labels.json。

标签JSON文件的数据结构:labels.json内容是这样的:

    [{"file_name": "20587062_124836763.jpg","text": "设施一流的绿色、舒适"},{"file_name": "20487921_757563219.jpg","text": "瘦削,二○○六年五月"},{"file_name": "20567468_1494490742.jpg","text": "分置改革方案》在法规"...

损失函数:
TrOCR 官方使用的损失函数是交叉熵损失(Cross-Entropy Loss),
主要用于计算解码器生成文本与真实标签之间的差异,具体是通过 标签移位(label shifting) 策略实现的序列到序列(Seq2Seq)损失计算。
TrOCR 是典型的编码器 - 解码器架构(图像编码器 + 文本解码器),
其损失计算逻辑与大多数 Seq2Seq 模型一致:

  • 输入与标签设计:解码器的输入是 “真实文本标签左移一位 + 起始符号(如 [CLS])”,标签是 “真实文本标签 + 终止符号(如 [SEP])”。
  • 损失计算:对解码器每个时间步的输出 logits 计算交叉熵损失,忽略 padding 位置(通过将 pad token 替换为 -100 实现,PyTorch 会自动忽略 -100 标签的损失)。
http://www.xdnf.cn/news/18405.html

相关文章:

  • Matplotlib+HTML+JS:打造可交互的动态数据仪表盘
  • 智慧工厂的 “隐形大脑”:边缘计算网关凭什么重构设备连接新逻辑?
  • 详细说明http协议特别是conten-length和chunk编码,并且用linux的命令行演示整个过程
  • Go语言变量声明与初始化详解
  • 一个状态机如何启动/停止另一个状态机
  • 【机器学习 / 深度学习】基础教程
  • StarRocks不能启动 ,StarRocksFe节点不能启动问题 处理
  • 生信分析自学攻略 | R语言函数与参数介绍
  • Notepad++换行符替换
  • 造成云手机闪退的原因有哪些?
  • HarmonyOS 实战:6 种实现实时数据更新的方案全解析(含完整 Demo)
  • java18学习笔记-Simple Web Server
  • 【LeetCode 415】—字符串相加算法详解
  • 【数据可视化-96】使用 Pyecharts 绘制主题河流图(ThemeRiver):步骤与数据组织形式
  • 深度学习-168-MCP技术之VSCode中安装插件Cline客户端应用MCP Server工具
  • 计算机网络-1——第一阶段
  • 在.NET 8 中使用中介模式优雅处理多版本 API 请求
  • 【51单片机】【protues仿真】基于51单片机16键电子琴系统
  • 高可用操作步骤
  • 纷玩岛协议抢票免费源码
  • Spring两个核心IoCDI(一)
  • java基础(十三)消息队列
  • #千问海报大赛
  • ORACLE中如何批量重置序列
  • 常德二院全栈国产化实践:KingbaseES 数据库的关键作用
  • PyTorch数据处理工具箱(可视化工具)
  • 大模型0基础开发入门与实践:第11章 进阶:LangChain与外部工具调用
  • Building Systems with the ChatGPT API 使用 ChatGPT API 搭建系统(第四章学习笔记及总结)
  • Eino 框架组件协作指南 - 智能图书馆建设手册
  • RAG学习(四)——使用混合检索进行检索优化