用 LoRA 对 Qwen2.5-VL 模型进行SFT - 训练哪些层
用 LoRA 对 Qwen2.5-VL 模型进行SFT - 训练哪些层
flyfish
可训练线性层
在 LoRA(Low-Rank Adaptation)微调中,lora_target=all
是一个关键配置选项,它允许系统自动识别并应用 LoRA 到模型中的所有合适线性层,无需手动指定具体模块。
1. 核心功能与工作流程
当配置文件中设置 lora_target: all
时,代码会执行以下操作:
-
参数解析:
lora_target: all
被解析为finetuning_args.lora_target = ["all"]
。 -
触发自动发现逻辑:
在_setup_lora_tuning
函数中:if finetuning_args.lora_target[0] == "all":target_modules = find_all_linear_modules(model, freeze_vision_tower)
-
模块扫描与筛选:
find_all_linear_modules
函数递归遍历模型的所有子模块,收集符合条件的线性层:- 必须是
nn.Linear
类型(排除嵌入层、归一化层等)。 - 不在禁止列表中(如
lm_head
、视觉塔模块等)。
- 必须是
2. 自动发现的实现细节
find_all_linear_modules
函数的核心逻辑:
def find_all_linear_modules(model, freeze_vision_tower):# 初始化禁止模块集合(根据模型类型动态调整)forbidden_modules = {"lm_head"} # 固定排除语言模型输出层if model_type == "chatglm":forbidden_modules.add("output_layer")if model_type in COMPOSITE_MODELS: # 多模态模型forbidden_modules.add(COMPOSITE_MODELS[model_type].projector_key)if freeze_vision_tower:forbidden_modules.update(COMPOSITE_MODELS[model_type].vision_model_keys)# 遍历所有模块,筛选线性层module_names = set()for name, module in model.named_modules():if any(forbidden in name for forbidden in forbidden_modules):continue # 跳过禁止模块if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:module_names.add(name.split(".")[-1]) # 提取模块名(如 "q_proj")return list(module_names)
3. 典型输出与模块类型
对于 Qwen2.5-VL 模型,lora_target=all
通常会发现以下线性层:
- 注意力机制模块:
q_proj
:查询矩阵投影k_proj
:键矩阵投影v_proj
:值矩阵投影o_proj
:输出投影
- 前馈网络模块:
gate_proj
:门控线性单元up_proj
:上投影(扩展维度)down_proj
:下投影(收缩维度)
4. 与日志的对应关系
在训练日志中,可以看到自动发现的模块列表:
[INFO] Found linear modules: up_proj,v_proj,q_proj,o_proj,gate_proj,k_proj,down_proj
这正是 find_all_linear_modules
函数的输出结果,显示 LoRA 将应用于这些模块。
5. 适用场景
优势 | 适用场景 |
---|---|
简化配置 | 无需手动列举所有可训练模块,尤其适合复杂模型 |
全面微调 | 确保覆盖所有关键线性层,充分利用 LoRA 的表达能力 |
多模态模型适配 | 自动处理视觉塔冻结逻辑,保护预训练的视觉能力 |
模型架构无关性 | 支持不同模型架构(Qwen、Llama、ChatGLM 等),自动适应其模块命名规则 |
LoRA 微调中排除了视觉塔(visual)、嵌入层和输出层
默认冻结视觉塔、嵌入层和输出层是一种平衡训练效率与模型性能的策略,适用于大多数指令微调场景。如需针对特定场景优化,可选择性解冻这些模块。
一、为什么冻结视觉塔(Visual Tower)?
1. 保留预训练视觉能力
视觉塔(如 CLIP 编码器)通常在大规模图像数据上预训练,已具备强大的视觉特征提取能力。冻结这些模块可以:
避免灾难性遗忘:防止在微调过程中破坏已有的视觉理解能力。
减少训练负担:视觉塔参数占比大,冻结后可显著降低计算开销。
2. 多模态模型的训练策略
多模态模型的核心目标是对齐视觉和语言表示,而非重新学习视觉特征。因此:
通常仅微调跨模态交互层(如视觉-语言融合层)和语言生成模块。
视觉塔作为“特征提取器”,保持不变以提供稳定的视觉输入。
3. 数据限制
微调阶段的数据集(如指令遵循数据)通常包含较少的图像样本,不足以支持视觉塔的有效训练。若强行微调,可能导致过拟合或视觉性能下降。
二、为什么冻结嵌入层(embed_tokens)?
1. 语义稳定性
词嵌入矩阵存储了词汇的语义表示,微调可能导致:
语义漂移:模型在小数据集上学习到的词向量可能偏离预训练的通用语义。
训练不稳定:嵌入层参数众多(如 Qwen2.5-VL 的 151936×2048 矩阵),微调需要大量数据才能收敛。
2. 高效微调的权衡
LoRA 的设计理念是通过少量可训练参数实现高效微调。冻结嵌入层符合这一原则,将资源集中在更关键的线性变换层(如注意力机制)。
三、为什么冻结输出层(lm_head)?
1. 保持语言生成能力
lm_head
负责将模型隐藏状态映射到词汇表概率分布,冻结它可以:
保留预训练的语言知识:避免在微调过程中破坏基础语言生成能力。
减少过拟合风险:输出层直接影响生成文本的质量,微调可能导致模型在特定数据集上过拟合。
2. 任务适配方式
在指令微调中,通常通过调整中间层(如注意力权重)来学习新任务,而非直接修改输出层。例如:
模型通过调整注意力机制来更好地理解指令,而输出层仍使用预训练的词汇映射。
四、何时需要解冻这些模块?
虽然默认冻结这些模块,但在以下场景中可考虑解冻:
1. 视觉塔解冻
任务需求:若需增强模型对特定领域图像的理解(如医学影像、工业检测)。
充足数据:拥有大规模特定领域的图像-文本对数据。
显存支持:解冻视觉塔会显著增加可训练参数,需确保有足够显存。
2. 嵌入层解冻
领域适配:针对特定领域(如法律、医疗)微调,需要更新领域专属词汇的表示。
多语言任务:扩展模型到新语言时,可能需要微调嵌入层以适应新词汇。
3. 输出层解冻
词汇表修改:若新增了特殊词汇或需要调整生成偏好。
特定生成任务:如诗歌生成、代码生成等需要定制输出分布的场景。
当设置 lora_target: all
时的微调层
以下面的结构为例
Qwen2_5_VLForConditionalGeneration((visual): Qwen2_5_VisionTransformerPretrainedModel((patch_embed): Qwen2_5_VisionPatchEmbed((proj): Conv3d(3, 1280, kernel_size=(2, 14, 14), stride=(2, 14, 14), bias=False))(rotary_pos_emb): Qwen2_5_VisionRotaryEmbedding()(blocks): ModuleList((0-31): 32 x Qwen2_5_VLVisionBlock((norm1): Qwen2RMSNorm((0,), eps=1e-06)(norm2): Qwen2RMSNorm((0,), eps=1e-06)(attn): Qwen2_5_VLVisionFlashAttention2((qkv): Linear(in_features=1280, out_features=3840, bias=True)(proj): Linear(in_features=1280, out_features=1280, bias=True))(mlp): Qwen2_5_VLMLP((gate_proj): Linear(in_features=1280, out_features=3420, bias=True)(up_proj): Linear(in_features=1280, out_features=3420, bias=True)(down_proj): Linear(in_features=3420, out_features=1280, bias=True)(act_fn): SiLU())))(merger): Qwen2_5_VLPatchMerger((ln_q): Qwen2RMSNorm((0,), eps=1e-06)(mlp): Sequential((0): Linear(in_features=5120, out_features=5120, bias=True)(1): GELU(approximate='none')(2): Linear(in_features=5120, out_features=2048, bias=True))))(model): Qwen2_5_VLModel((embed_tokens): Embedding(151936, 2048)(layers): ModuleList((0-35): 36 x Qwen2_5_VLDecoderLayer((self_attn): Qwen2_5_VLFlashAttention2((q_proj): Linear(in_features=2048, out_features=2048, bias=True)(k_proj): Linear(in_features=2048, out_features=256, bias=True)(v_proj): Linear(in_features=2048, out_features=256, bias=True)(o_proj): Linear(in_features=2048, out_features=2048, bias=False)(rotary_emb): Qwen2_5_VLRotaryEmbedding())(mlp): Qwen2MLP((gate_proj): Linear(in_features=2048, out_features=11008, bias=False)(up_proj): Linear(in_features=2048, out_features=11008, bias=False)(down_proj): Linear(in_features=11008, out_features=2048, bias=False)(act_fn): SiLU())(input_layernorm): Qwen2RMSNorm((0,), eps=1e-06)(post_attention_layernorm): Qwen2RMSNorm((0,), eps=1e-06)))(norm): Qwen2RMSNorm((0,), eps=1e-06)(rotary_emb): Qwen2_5_VLRotaryEmbedding())(lm_head): Linear(in_features=2048, out_features=151936, bias=False)
)
根据模型结构和代码逻辑,当配置 lora_target: all
且 freeze_vision_tower: true
时,LoRA 将微调以下层:
1. 语言模型部分(model
)
Transformer 层中的注意力模块
- 所有 36 个 DecoderLayer 中的注意力投影矩阵:
model.layers.[0-35].self_attn.q_proj model.layers.[0-35].self_attn.k_proj model.layers.[0-35].self_attn.v_proj model.layers.[0-35].self_attn.o_proj
Transformer 层中的 MLP 模块
- 所有 36 个 DecoderLayer 中的 MLP 线性层:
model.layers.[0-35].mlp.gate_proj model.layers.[0-35].mlp.up_proj model.layers.[0-35].mlp.down_proj
2. 不被微调的模块
视觉塔(visual
)
当 freeze_vision_tower: true
时,以下模块全部被冻结:
visual.patch_embed.proj
visual.blocks.[0-31].attn.qkv
和proj
visual.blocks.[0-31].mlp.gate_proj
、up_proj
、down_proj
visual.merger.mlp.0
和mlp.2
(视觉-语言融合层)
嵌入层和输出层
model.embed_tokens
(词嵌入)lm_head
(语言模型输出层)
3. 日志验证
训练日志中的关键输出:
target_modules:['up_proj', 'v_proj', 'q_proj', 'o_proj', 'gate_proj', 'k_proj', 'down_proj'].
Set vision model not trainable: ['visual.patch_embed', 'visual.blocks'].
Set multi model projector not trainable: visual.merger.
- 第一行:显示的模块名对应
model
中的注意力和 MLP 层。 - 第二行:确认视觉塔被冻结。
- 保留预训练视觉能力:视觉塔已在大规模图像数据上预训练,冻结可避免灾难性遗忘。
- 优化语言生成和跨模态对齐:通过微调语言模型的线性层,提升模型对视觉输入的理解和响应能力。
Qwen2.5-VL 使用 加粗 表示微调层,# 冻结 标注冻结模块
Qwen2_5_VLForConditionalGeneration((visual): Qwen2_5_VisionTransformerPretrainedModel((patch_embed): Qwen2_5_VisionPatchEmbed((proj): Conv3d(3, 1280, kernel_size=(2, 14, 14), stride=(2, 14, 14), bias=False) # 冻结:视觉塔输入投影(有参数,停止更新))(rotary_pos_emb): Qwen2_5_VisionRotaryEmbedding() # 无参数:旋转位置编码(纯计算,无参数)(blocks): ModuleList((0-31): 32 x Qwen2_5_VLVisionBlock((norm1): Qwen2RMSNorm((0,), eps=1e-06) # 冻结:归一化层(有参数,停止更新)(norm2): Qwen2RMSNorm((0,), eps=1e-06) # 冻结:归一化层(有参数,停止更新)(attn): Qwen2_5_VLVisionFlashAttention2((qkv): Linear(in_features=1280, out_features=3840, bias=True) # 冻结:视觉注意力投影(有参数,停止更新)(proj): Linear(in_features=1280, out_features=1280, bias=True) # 冻结:视觉注意力输出(有参数,停止更新))(mlp): Qwen2_5_VLMLP((gate_proj): Linear(in_features=1280, out_features=3420, bias=True) # 冻结:视觉MLP门控(有参数,停止更新)(up_proj): Linear(in_features=1280, out_features=3420, bias=True) # 冻结:视觉MLP上投影(有参数,停止更新)(down_proj): Linear(in_features=3420, out_features=1280, bias=True) # 冻结:视觉MLP下投影(有参数,停止更新)(act_fn): SiLU() # 无参数:激活函数(纯运算,无参数))))(merger): Qwen2_5_VLPatchMerger((ln_q): Qwen2RMSNorm((0,), eps=1e-06) # 冻结:归一化层(有参数,停止更新)(mlp): Sequential((0): Linear(in_features=5120, out_features=5120, bias=True) # 冻结:视觉-语言融合投影(有参数,停止更新)(1): GELU(approximate='none') # 无参数:激活函数(纯运算,无参数)(2): Linear(in_features=5120, out_features=2048, bias=True) # 冻结:视觉-语言融合输出(有参数,停止更新))))(model): Qwen2_5_VLModel((embed_tokens): Embedding(151936, 2048) # 冻结:词嵌入(有参数,停止更新)(layers): ModuleList((0-35): 36 x Qwen2_5_VLDecoderLayer((self_attn): Qwen2_5_VLFlashAttention2(**(q_proj): Linear(in_features=2048, out_features=2048, bias=True)** # 微调:查询投影(有参数,参与更新)**(k_proj): Linear(in_features=2048, out_features=256, bias=True)** # 微调:键投影(有参数,参与更新)**(v_proj): Linear(in_features=2048, out_features=256, bias=True)** # 微调:值投影(有参数,参与更新)**(o_proj): Linear(in_features=2048, out_features=2048, bias=False)** # 微调:注意力输出(有参数,参与更新)(rotary_emb): Qwen2_5_VLRotaryEmbedding() # 无参数:旋转位置编码(纯计算,无参数))(mlp): Qwen2MLP(**(gate_proj): Linear(in_features=2048, out_features=11008, bias=False)** # 微调:MLP门控(有参数,参与更新)**(up_proj): Linear(in_features=2048, out_features=11008, bias=False)** # 微调:MLP上投影(有参数,参与更新)**(down_proj): Linear(in_features=11008, out_features=2048, bias=False)** # 微调:MLP下投影(有参数,参与更新)(act_fn): SiLU() # 无参数:激活函数(纯运算,无参数))(input_layernorm): Qwen2RMSNorm((0,), eps=1e-06) # 冻结:输入归一化(有参数,停止更新)(post_attention_layernorm): Qwen2RMSNorm((0,), eps=1e-06) # 冻结:注意力后归一化(有参数,停止更新)))(norm): Qwen2RMSNorm((0,), eps=1e-06) # 冻结:最终归一化(有参数,停止更新)(rotary_emb): Qwen2_5_VLRotaryEmbedding() # 无参数:旋转位置编码(纯计算,无参数))(lm_head): Linear(in_features=2048, out_features=151936, bias=False) # 冻结:语言模型输出层(有参数,停止更新)
)
参数性质分类
模块类型 | 示例模块 | 参数性质 | 微调时处理方式 |
---|---|---|---|
线性层/卷积层 | q_proj、proj、embed_tokens | 有参数 | 微调(更新)或冻结 |
归一化层 | norm1、ln_q | 有参数(如缩放因子) | 通常冻结 |
激活函数 | act_fn、GELU | 无参数 | 无需处理(运算固定) |
位置编码 | rotary_pos_emb | 无参数(纯计算) | 无需处理 |
代码参考1
def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> list[str]:r"""Find all available modules to apply LoRA, GaLore or APOLLO."""model_type = getattr(model.config, "model_type", None)forbidden_modules = {"lm_head"}if model_type == "chatglm":forbidden_modules.add("output_layer")elif model_type == "internlm2":forbidden_modules.add("output")if model_type in COMPOSITE_MODELS:forbidden_modules.add(COMPOSITE_MODELS[model_type].projector_key)if freeze_vision_tower and model_type in COMPOSITE_MODELS:forbidden_modules.update(COMPOSITE_MODELS[model_type].vision_model_keys)module_names = set()for name, module in model.named_modules():if any(forbidden_module in name for forbidden_module in forbidden_modules):continueif "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:module_names.add(name.split(".")[-1])logger.info_rank0("Found linear modules: {}".format(",".join(module_names)))return list(module_names)
代码参考2
def _setup_freeze_tuning(model: "PreTrainedModel",finetuning_args: "FinetuningArguments",is_trainable: bool,cast_trainable_params_to_fp32: bool,
) -> None:if not is_trainable:returnlogger.info_rank0("Fine-tuning method: Freeze")if hasattr(model.config, "text_config"): # composite modelsconfig = getattr(model.config, "text_config")else:config = model.confignum_layers = (getattr(config, "num_hidden_layers", None)or getattr(config, "num_layers", None)or getattr(config, "n_layer", None))if not num_layers:raise ValueError("Current model does not support freeze tuning.")if finetuning_args.use_llama_pro:if num_layers % finetuning_args.freeze_trainable_layers != 0:raise ValueError(f"`num_layers` {num_layers} should be "f"divisible by `num_layer_trainable` {finetuning_args.freeze_trainable_layers}.")stride = num_layers // finetuning_args.freeze_trainable_layerstrainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)elif finetuning_args.freeze_trainable_layers > 0: # fine-tuning the last n layers if num_layer_trainable > 0trainable_layer_ids = range(max(0, num_layers - finetuning_args.freeze_trainable_layers), num_layers)else: # fine-tuning the first n layers if num_layer_trainable < 0trainable_layer_ids = range(min(-finetuning_args.freeze_trainable_layers, num_layers))hidden_modules = set()non_hidden_modules = set()for name, _ in model.named_parameters():if ".0." in name:hidden_modules.add(name.split(".0.")[-1].split(".")[0])elif ".1." in name: # MoD starts from layer 1hidden_modules.add(name.split(".1.")[-1].split(".")[0])if re.search(r"\.\d+\.", name) is None:non_hidden_modules.add(name.split(".")[-2]) # remove weight/biastrainable_layers = []for module_name in finetuning_args.freeze_trainable_modules:if module_name != "all" and module_name not in hidden_modules:raise ValueError("Module {} is not found, please choose from {}".format(module_name, ", ".join(hidden_modules)))for idx in trainable_layer_ids:trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else ""))if finetuning_args.freeze_extra_modules:for module_name in finetuning_args.freeze_extra_modules:if module_name not in non_hidden_modules:raise ValueError("Module {} is not found, please choose from {}".format(module_name, ", ".join(non_hidden_modules)))trainable_layers.append(module_name)model_type = getattr(model.config, "model_type", None)if not finetuning_args.freeze_multi_modal_projector and model_type in COMPOSITE_MODELS:trainable_layers.append(COMPOSITE_MODELS[model_type].projector_key)forbidden_modules = get_forbidden_modules(model.config, finetuning_args)for name, param in model.named_parameters():if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(forbidden_module in name for forbidden_module in forbidden_modules):if cast_trainable_params_to_fp32:param.data = param.data.to(torch.float32)else:param.requires_grad_(False)logger.info_rank0("Set trainable layers: {}".format(",".join(trainable_layers)))def _setup_lora_tuning(config: "PretrainedConfig",model: "PreTrainedModel",model_args: "ModelArguments",finetuning_args: "FinetuningArguments",is_trainable: bool,cast_trainable_params_to_fp32: bool,
) -> "PeftModel":if is_trainable:logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))adapter_to_resume = Noneif model_args.adapter_name_or_path is not None:is_mergeable = Trueif getattr(model, "quantization_method", None): # merge lora in quantized model is unstableassert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter."is_mergeable = Falseif is_deepspeed_zero3_enabled():assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."is_mergeable = Falseif model_args.use_unsloth:assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter."is_mergeable = Falseif (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):adapter_to_merge = model_args.adapter_name_or_path[:-1]adapter_to_resume = model_args.adapter_name_or_path[-1]else:adapter_to_merge = model_args.adapter_name_or_pathinit_kwargs = {"subfolder": model_args.adapter_folder,"offload_folder": model_args.offload_folder,"cache_dir": model_args.cache_dir,"revision": model_args.model_revision,"token": model_args.hf_hub_token,}for adapter in adapter_to_merge:model: LoraModel = PeftModel.from_pretrained(model, adapter, **init_kwargs)model = model.merge_and_unload()if len(adapter_to_merge) > 0:logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")if adapter_to_resume is not None: # resume lora trainingif model_args.use_unsloth:model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)else:model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))if is_trainable and adapter_to_resume is None: # create new lora weights while trainingif len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":target_modules = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)else:target_modules = finetuning_args.lora_targetlogger.info_rank0(f"target_modules:{target_modules}.")if finetuning_args.use_llama_pro:target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)target_modules = patch_target_modules(model, finetuning_args, target_modules)if (finetuning_args.use_doraand getattr(model, "quantization_method", None) is not Noneand getattr(model, "quantization_method", None) != QuantizationMethod.BNB):raise ValueError("DoRA is not compatible with PTQ-quantized models.")if model_args.resize_vocab and finetuning_args.additional_target is None:input_embeddings = model.get_input_embeddings()output_embeddings = model.get_output_embeddings()module_names = set()for name, module in model.named_modules():if module in [input_embeddings, output_embeddings]:module_names.add(name.split(".")[-1])finetuning_args.additional_target = module_nameslogger.warning_rank0("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))peft_kwargs = {"r": finetuning_args.lora_rank,"target_modules": target_modules,"lora_alpha": finetuning_args.lora_alpha,"lora_dropout": finetuning_args.lora_dropout,"use_rslora": finetuning_args.use_rslora,"use_dora": finetuning_args.use_dora,"modules_to_save": finetuning_args.additional_target,}if model_args.use_unsloth:model = get_unsloth_peft_model(model, model_args, peft_kwargs)else:if finetuning_args.pissa_init:if finetuning_args.pissa_iter == -1:logger.info_rank0("Using PiSSA initialization.")peft_kwargs["init_lora_weights"] = "pissa"else:logger.info_rank0(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM,inference_mode=False,**peft_kwargs,)model = get_peft_model(model, lora_config)if is_trainable and cast_trainable_params_to_fp32:for param in filter(lambda p: p.requires_grad, model.parameters()):param.data = param.data.to(torch.float32)return model