【TrOCR】根据任务特性设计词表vocab.json
预训练权重的限制
基于对
microsoft/trocr-base-printed
预训练模型文件夹中的各个文件的理解,思考以下问题:
1、你觉得哪些文件在TrOCR的训练中起到作用,并简短说明是什么作用。
2、哪些文件经过TrOCR的训练调参后发生改变(也就是说通过训练得到的,而不是人物设计的)。
3、一些对于TrOCR模型的训练和推理都没有作用的文件,是在哪里起作用,是识别完文字的下一个阶段用到?
4、训练不改变内容,在仅仅用于文字识别TrOCR任务的训练推理时,不起作用的文件是否就可以删除?
5、预训练权重是这些,那微调后的权重文件夹内容也是这些吗?
1、在TrOCR训练中起作用的文件及作用
config.json
:定义模型整体架构(编码器/解码器层数、维度等),是训练时模型结构的“蓝图”,决定网络参数的初始化和计算逻辑。preprocessor_config.json
:规定图像预处理规则(缩放、归一化等),确保训练数据的输入格式与模型要求一致。tokenizer_config.json
、vocab.json
、merges.txt
、special_tokens_map.json
:共同定义文本分词逻辑(词表、子词合并规则、特殊标记),将训练文本转换为模型可理解的Token ID。model.safetensors
:存储初始预训练权重,训练中会基于新数据更新这些权重参数。generation_config.json
:控制训练时文本生成的策略(如起始/结束标记、生成长度限制),影响模型的拟合方向。
2、训练调参后发生改变的文件(训练得到,非人工设计)
仅 model.safetensors
。
该文件存储模型的所有权重参数(编码器/解码器的注意力权重、层归一化参数等),训练过程中通过反向传播不断更新这些参数,是唯一由训练直接生成的文件。其他文件(如配置文件、分词文件)均为人工设计的“规则/结构定义”,训练中不会被修改。
3、对训练和推理无直接作用的文件及作用场景
README.md
:是模型的说明文档,用于指导用户理解模型用途、使用方法、引用信息等,不参与训练/推理计算,仅在用户学习和使用模型时提供参考。.gitattributes
:是Git版本控制的配置文件,用于指定大文件(如model.safetensors
)由Git LFS管理,确保仓库高效存储和协作,与模型的训练/推理功能无关,仅作用于代码仓库的版本管理环节。
这些文件与“文字识别后的下阶段”无关,主要服务于模型的分发、管理和用户指导。
4、训练不改变内容且对识别任务无直接作用的文件是否可删除?
不建议删除。
README.md
是快速上手模型的关键参考,删除后用户难以了解模型用法和限制;.gitattributes
缺失会导致Git无法正确处理大文件(如model.safetensors
可能被当作普通文件存储,导致仓库膨胀),影响后续版本管理和协作。
仅在极端极简部署场景(如嵌入式设备),且明确无需版本管理和用户指导时,可考虑删除,但需谨慎。
5、预训练权重与微调后权重的文件夹内容是否相同?
基本相同,核心文件结构一致。
微调仅更新 model.safetensors
(权重参数),可能会调整 generation_config.json
(若微调时修改了生成策略),但其他文件(config.json
、分词文件、预处理配置等)均保持不变——因为微调是在预训练模型的结构基础上调整权重,不会改变模型的架构、数据处理规则等“底层设计”。
模型前向传播
查看模型结构和词表长度
现在我print出了加载预训练权重模型结构:
from transformers import VisionEncoderDecoderModelPRETRAINED_MODEL_PATH = r"C:\Users\Virgil\Desktop\TrOCR\trocr-base-printed"
model = VisionEncoderDecoderModel.from_pretrained(PRETRAINED_MODEL_PATH)print("模型结构")
print(model)print("------------------------------------------")
model.config.vocab_size = model.config.decoder.vocab_sizeprint("词表长度:")
print(model.config.vocab_size)
VisionEncoderDecoderModel((encoder): ViTModel((embeddings): ViTEmbeddings((patch_embeddings): ViTPatchEmbeddings((projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16)))(dropout): Dropout(p=0.0, inplace=False))(encoder): ViTEncoder((layer): ModuleList((0-11): 12 x ViTLayer((attention): ViTAttention((attention): ViTSelfAttention((query): Linear(in_features=768, out_features=768, bias=False)(key): Linear(in_features=768, out_features=768, bias=False)(value): Linear(in_features=768, out_features=768, bias=False))(output): ViTSelfOutput((dense): Linear(in_features=768, out_features=768, bias=True)(dropout): Dropout(p=0.0, inplace=False)))(intermediate): ViTIntermediate((dense): Linear(in_features=768, out_features=3072, bias=True)(intermediate_act_fn): GELUActivation())(output): ViTOutput((dense): Linear(in_features=3072, out_features=768, bias=True)(dropout): Dropout(p=0.0, inplace=False))(layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)(layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True))))(layernorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)(pooler): ViTPooler((dense): Linear(in_features=768, out_features=768, bias=True)(activation): Tanh()))(decoder): TrOCRForCausalLM((model): TrOCRDecoderWrapper((decoder): TrOCRDecoder((embed_tokens): TrOCRScaledWordEmbedding(50265, 1024, padding_idx=1)(embed_positions): TrOCRLearnedPositionalEmbedding(514, 1024)(layernorm_embedding): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(layers): ModuleList((0-11): 12 x TrOCRDecoderLayer((self_attn): TrOCRAttention((k_proj): Linear(in_features=1024, out_features=1024, bias=True)(v_proj): Linear(in_features=1024, out_features=1024, bias=True)(q_proj): Linear(in_features=1024, out_features=1024, bias=True)(out_proj): Linear(in_features=1024, out_features=1024, bias=True))(activation_fn): GELUActivation()(self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(encoder_attn): TrOCRAttention((k_proj): Linear(in_features=768, out_features=1024, bias=True)(v_proj): Linear(in_features=768, out_features=1024, bias=True)(q_proj): Linear(in_features=1024, out_features=1024, bias=True)(out_proj): Linear(in_features=1024, out_features=1024, bias=True))(encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(fc1): Linear(in_features=1024, out_features=4096, bias=True)(fc2): Linear(in_features=4096, out_features=1024, bias=True)(final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)))))(output_projection): Linear(in_features=1024, out_features=50265, bias=False))
)
打印出的词表长度:50265
前向传播与输出logits
讲解
每一轮训练
# 前向传播(获取解码器输出logits)outputs = model(pixel_values=pixel_values, labels=labels)logits = outputs.logits # 形状: (batch_size, max_length, vocab_size)
在TrOCR训练的每一轮前向传播中,model(pixel_values=pixel_values, labels=labels)
是核心计算过程,结合模型结构可拆解为以下步骤,最终得到的 logits
是模型预测的核心输出:
1. model()
的输入解析
pixel_values
:预处理后的图像张量,形状通常为(batch_size, 3, 384, 384)
(批量大小、RGB通道、图像尺寸)。由preprocessor_config.json
定义的规则处理(缩放、归一化等),确保与编码器输入要求一致。labels
:图像对应的文本标签经分词后的Token ID张量,形状为(batch_size, max_length)
(max_length
为文本最大长度,短文本用<pad>
填充)。用于计算预测损失,指导模型参数更新。
2. 模型前向传播过程(结合结构细节)
模型按“编码器→解码器”流程处理输入,最终输出 logits
:
(1)编码器(ViTModel)处理图像
- 图像分块与嵌入:
输入图像先经ViTEmbeddings.patch_embeddings
(16x16卷积)分割为 24x24=576 个补丁(384/16=24),每个补丁被投影为 768 维向量(与config.json
中编码器hidden_size=768
对应),再添加位置嵌入并经 dropout 处理。 - Transformer编码:
嵌入后的补丁序列(576个向量)输入ViTEncoder
的12层ViTLayer
:- 每层通过
ViTAttention
(多头自注意力)捕捉补丁间的空间关系; - 经
ViTIntermediate
(3072维前馈网络)和ViTOutput
转换特征,配合两层LayerNorm(layernorm_before
/after
)稳定训练。
- 每层通过
- 输出图像特征:
编码器最终输出形状为(batch_size, 576, 768)
的图像特征(批量、补丁数、特征维度),作为解码器的“视觉输入”。
(2)解码器(TrOCRForCausalLM)生成文本预测
解码器以“图像特征+文本标签”为输入,通过自回归逻辑生成文本预测:
- 文本嵌入:
labels
(文本Token ID)先经TrOCRScaledWordEmbedding
转换为 1024 维词向量(与解码器d_model=1024
对应),添加TrOCRLearnedPositionalEmbedding
(位置信息)并经LayerNorm处理。 - Transformer解码(12层
TrOCRDecoderLayer
):
每层包含两个关键注意力机制:- 自注意力(self_attn):关注已生成的文本Token(如第i步关注前i-1个Token),通过
k_proj/q_proj/v_proj
计算1024维注意力分布,捕捉文本序列的上下文关系; - 交叉注意力(encoder_attn):关注编码器输出的图像特征(768维),通过
k_proj/v_proj
将图像特征转换为1024维(与解码器维度匹配),再与文本特征计算注意力,确保生成的文本与图像内容对齐。
每层还通过fc1
(1024→4096)和fc2
(4096→1024)的前馈网络增强特征表达,配合LayerNorm稳定训练。
- 自注意力(self_attn):关注已生成的文本Token(如第i步关注前i-1个Token),通过
- 输出投影:
解码器最终输出的文本特征(1024维)经output_projection
(线性层)映射到词表空间,得到logits
。
3. logits
的含义与形状
- 形状:
(batch_size, max_length, vocab_size)
,对应(批量大小, 文本长度, 词表大小)
。其中vocab_size=50265
(与vocab.json
词表大小一致)。 - 含义:
logits[i, j, k]
表示“第i个样本、第j个位置的Token为词表中第k个词”的原始预测分数(未归一化的概率)。后续会通过softmax转换为概率分布,用于计算与labels
的交叉熵损失(衡量预测与真实标签的差异)。
这一步前向传播的核心是:编码器将图像转换为视觉特征,解码器结合视觉特征和文本标签生成逐位置的词表预测(logits)。logits
既是预测结果的直接输出,也是计算损失、反向传播更新 model.safetensors
中权重参数的关键依据,最终让模型学习到“图像→文本”的映射规律。
词表设计思路与分析
词表vocab文件里的各个token很多都不是单个字母或者单个数字。这个OCR任务就是单纯的识别出是什么字母数字或者汉字。那词表没必要这么复杂。可以设计为仅保留 “单个字符 + 核心特殊 Token” 的词表设计(总规模几千个)
- 核心特殊 Token
- 常用的几千个中文汉字,
- 52个英文字母
- 10个数字
- 各个特殊符号
这种仅保留“单个字符+核心特殊Token”的词表设计(总规模几千个)在特定场景下有其合理性,但也会带来显著的优缺点,且需要配套调整模型和训练策略,具体分析如下:
一、这种设计的核心优势
-
词表规模大幅减小
从50265缩减到几千,直接降低模型的词嵌入层参数规模(原解码器词嵌入层参数为50265×1024≈5100万
,新设计为5000×1024≈512万
,参数减少90%),可降低内存占用和计算成本,适合资源有限的场景(如嵌入式设备)。 -
分词逻辑简化
无需BPE分词的复杂合并规则(merges.txt
可省略),文本直接按单个字符拆分(如“apple”拆分为a→p→p→l→e
),分词过程更直观,减少因子词拆分导致的识别误差(如特殊组合被错误拆分)。 -
适配“纯字符识别”场景
若OCR任务仅针对高度标准化的文本(如身份证、发票上的固定格式字符,无复杂缩写、生僻组合),这种词表可覆盖所有可能输入,避免子词表带来的冗余。
二、潜在问题与局限性
-
丢失“上下文依赖”捕捉能力
原词表的子词(如“app”“ing”)本质是对“高频字符组合”的预编码,模型可通过这些子词快速学习字符间的关联(如“q”后几乎必接“u”)。而纯字符词表中,每个字符独立拆分,模型需要从头学习所有组合规律,对模糊、倾斜或噪声图像的鲁棒性会下降(例如“th”在模糊图像中,子词表模型可能通过整体特征识别,而纯字符模型可能因单个“t”或“h”模糊导致错误)。 -
预训练权重无法直接复用
原预训练模型的解码器(基于RoBERTa)是为50265词表设计的,其词嵌入层、注意力权重等参数均与该词表绑定。换成新词表后:- 词嵌入层维度不匹配(需重新初始化或裁剪,导致预训练的“语言知识”丢失);
- 解码器的交叉注意力、前馈网络等参数是基于子词分布训练的,换成纯字符后,参数分布与新数据不匹配,需要从头训练或大规模微调,成本显著增加。
-
对复杂文本的覆盖不足
三、适用场景与建议
- 适合场景:高度标准化、字符有限的封闭场景(如快递单号识别、车牌识别,字符集固定且无复杂组合)。
- 不适合场景:通用OCR(如书籍、网页、手写体),或包含复杂文本结构的场景。
若坚持使用这种设计,需配套调整:
- 重新初始化解码器的词嵌入层(适配新词表大小);
- 增加训练数据中字符级别的增强(如模糊、旋转、噪声),帮助模型学习单个字符的鲁棒特征;
- 移除原词表相关文件(
merges.txt
不再需要),仅保留字符映射的vocab.json
和特殊Token配置。
这种“极简字符词表”是对特定场景的优化,而非通用方案。它通过牺牲对复杂文本的适应能力,换取参数规模和计算效率的提升,但需要重新训练模型,且无法复用原预训练的子词级语言知识。是否采用,需根据具体任务的文本复杂度和资源约束决定。
词表生成
52个英文字母和10个数字很简单,
特殊Token就直接用预训练权重的前4个特殊token:
"<s>": 0,"<pad>": 1,"</s>": 2,"<unk>": 3,
汉字和特殊符号需要找:
5020 个常用汉字资源文件:收录于 “常用汉字大全.txt”,项目地址为 https://gitcode.com/Open-source-documentation-tutorial/4afa4
https://github.com/wy-luke/All-Chinese-Character-Set/blob/main/symbols.txt
https://github.com/DenverCoder1/latex-gboard-dictionary/blob/master/dictionary.txt
import jsondef read_chars_from_file(file_path):"""从文件中读取字符,每行一个字符"""chars = []try:with open(file_path, 'r', encoding='utf-8') as f:for line in f:# 去除每行的换行符,但保留原始字符(包括空格、tab等)char = line.rstrip('\n')if char: # 确保不添加空字符串chars.append(char)except FileNotFoundError:print(f"警告:未找到文件 {file_path},将跳过该文件的处理")return chars# 存储所有被跳过的字符
skipped_chars = []# 1. 定义核心特殊token(按优先级排序)
special_tokens = ["<s>", # 起始标记"<pad>", # 填充标记"</s>", # 结束标记"<unk>" # 未知标记
]
special_count = len(special_tokens)
special_start_idx = 0
special_end_idx = special_count - 1# 2. 从文件读取字符
symbols = read_chars_from_file("symbols_2.txt")
symbol_count = len(symbols)
symbol_start_idx = special_end_idx + 1 if special_count > 0 else 0
symbol_end_idx = symbol_start_idx + symbol_count - 1 if symbol_count > 0 else -1math_chars = read_chars_from_file("math.txt")
original_math_count = len(math_chars)chinese_chars = read_chars_from_file("chinese_5021.txt")
original_chinese_count = len(chinese_chars)# 3. 处理数学符号:去除与特殊符号重复的字符
unique_math_chars = []
symbol_set = set(symbols)
skipped_math_count = 0for char in math_chars:if char not in symbol_set and char not in unique_math_chars:unique_math_chars.append(char)else:skipped_math_count += 1skipped_chars.append(char) # 记录被跳过的数学符号math_count = len(unique_math_chars)
math_start_idx = symbol_end_idx + 1 if symbol_count > 0 else special_end_idx + 1
math_end_idx = math_start_idx + math_count - 1 if math_count > 0 else -1# 4. 处理中文字符:确保唯一性
existing_chars = set(special_tokens + symbols + unique_math_chars) # 前面字符集合
chinese_internal_duplicates = 0 # 中文内部重复计数
chinese_external_duplicates = 0 # 与其他文件重复计数
unique_chinese_chars = []for char in chinese_chars:if char in existing_chars:# 与前面的特殊符号/数学符号等重复chinese_external_duplicates += 1skipped_chars.append(char) # 记录与其他文件重复的中文elif char in unique_chinese_chars:# 中文内部重复chinese_internal_duplicates += 1skipped_chars.append(char) # 记录中文内部重复的字符else:unique_chinese_chars.append(char)existing_chars.add(char)# 总跳过数量 = 内部重复 + 外部重复
skipped_chinese_count = chinese_internal_duplicates + chinese_external_duplicateschinese_count = len(unique_chinese_chars)
chinese_start_idx = math_end_idx + 1 if math_count > 0 else symbol_end_idx + 1
chinese_end_idx = chinese_start_idx + chinese_count - 1 if chinese_count > 0 else -1# 5. 合并所有字符,保持指定顺序
all_chars = special_tokens + symbols + unique_math_chars + unique_chinese_chars
total_count = len(all_chars)# 6. 生成词表字典(字符到索引的映射)
vocab = {char: idx for idx, char in enumerate(all_chars)}# 7. 保存为JSON文件
with open("vocab_2.json", 'w', encoding='utf-8') as f:json.dump(vocab, f, ensure_ascii=False, indent=2)# 8. 保存所有被跳过的字符到skip.txt
with open("skip.txt", 'w', encoding='utf-8') as f:for char in skipped_chars:f.write(char + '\n')# 9. 输出统计信息
print("=" * 50)
print("词表生成统计信息:")
print("=" * 50)
print(f"特殊Token:共 {special_count} 个")
print(f" 索引区间:[{special_start_idx}, {special_end_idx}]")
print(f" 内容:{special_tokens}")
print("-" * 50)
print(f"特殊符号:共 {symbol_count} 个")
print(f" 索引区间:[{symbol_start_idx}, {symbol_end_idx}]")
print("-" * 50)
print(f"数学符号:")
print(f" 原始数量:{original_math_count} 个")
print(f" 写入数量:{math_count} 个(去重后)")
print(f" 跳过数量:{skipped_math_count} 个(与特殊符号重复)")
print(f" 索引区间:[{math_start_idx}, {math_end_idx}]")
print("-" * 50)
print(f"中文字符:")
print(f" 原始数量:{original_chinese_count} 个")
print(f" 写入数量:{chinese_count} 个(去重后)")
print(f" 与其他文件重复:{chinese_external_duplicates} 个")
print(f" 中文内部重复:{chinese_internal_duplicates} 个")
print(f" 总跳过数量:{skipped_chinese_count} 个")
print(f" 索引区间:[{chinese_start_idx}, {chinese_end_idx}]")
print("-" * 50)
print(f"所有被跳过的字符已保存到:skip.txt(共 {len(skipped_chars)} 个)")
print(f"词表总字符数:{total_count} 个")
print(f"词表已保存为:vocab_3.json")
print("=" * 50)
==================================================
词表生成统计信息:
==================================================
特殊Token:共 4 个索引区间:[0, 3]内容:['<s>', '<pad>', '</s>', '<unk>']
--------------------------------------------------
特殊符号:共 128 个索引区间:[4, 131]
--------------------------------------------------
数学符号:原始数量:919 个写入数量:708 个(去重后)跳过数量:211 个(与特殊符号重复)索引区间:[132, 839]
--------------------------------------------------
中文字符:原始数量:5021 个写入数量:2501 个(去重后)与其他文件重复:2520 个中文内部重复:0 个总跳过数量:2520 个索引区间:[840, 3340]
--------------------------------------------------
所有被跳过的字符已保存到:skip.txt(共 2731 个)
词表总字符数:3341 个
词表已保存为:vocab_3.json
==================================================
修改model模型文件
vocab.json
换成新的同名的词表文件
config.json
需要修改词表大小:
"vocab_size": 3340
generation_config.json
文件不用修改,因为四个特殊token的id都没有变。
tokenizer_config.json
文件不用修改
{"_from_model_config": true,"bos_token_id": 0,"decoder_start_token_id": 2,"eos_token_id": 2,"pad_token_id": 1,"transformers_version": "4.27.0.dev0","use_cache": false
}
权重文件model.safetensors
TrOCR(基于 Transformer 的 OCR 模型)的核心组件(如解码器的嵌入层和输出层)的维度与词表大小强相关:
-
嵌入层(Embedding Layer):输入维度为vocab_size,输出维度为模型隐藏层大小(如 768)。预训练模型的嵌入层权重是基于原词表的,若新词表大小(3340)与原词表不同,嵌入层的参数维度会不匹配,直接加载会报错。
- 词嵌入层实际路径是 model.decoder.model.decoder.embed_tokens(嵌套在 model→decoder 层级下)
-
输出层(Output Layer):输出维度为vocab_size(用于预测每个字符的概率),预训练的输出层权重同样依赖原词表大小,新词表大小变化后,输出层参数维度也会不匹配。
- 输出层实际路径是 model.decoder.output_projection(直接在解码器下)