当前位置: 首页 > web >正文

【diffusers 进阶之 PEFT 入门(五)】LoRA 权重如何接着训练?踩坑总结和解决方案

系列文章目录

  • 【diffusers 进阶之 PEFT 入门(一)】 inject_adapter_in_model 详解
  • 【diffusers 进阶之 PEFT 入门(二)】LoraConfig 如何处理 lora_config 参数的?
  • 【diffusers 进阶之 PEFT 入门(三)】BaseTunerLayer 与 set_adapter
  • 【diffusers 进阶之 PEFT 入门(四)】load_lora_weight 的踩坑总结以及解决方案

文章目录

  • 系列文章目录
  • 深入理解 LoRA 权重加载:避免训练状态重置的陷阱
    • 问题现象
    • 问题分析
    • 解决方案
    • 技术要点
      • 深入 LoRA 加载机制
    • 结论


深入理解 LoRA 权重加载:避免训练状态重置的陷阱

在使用 LoRA (Low-Rank Adaptation) 进行模型微调时,我们经常需要从检查点恢复训练。然而,如果加载预训练权重的方式不当,可能会导致训练状态被意外重置。本文将分享一个实际案例,展示如何正确加载 LoRA 权重以继续训练。

问题现象

在我们的训练过程中,虽然正确加载了优化器状态,但模型的 loss 值却从高值重新开始:

# 期望的 loss 值(之前的训练状态)
Epoch: 2, Steps: 70070, Batch: 21771, Loss: 0.0898
Epoch: 2, Steps: 70080, Batch: 21781, Loss: 0.0868# 实际的 loss 值(意外重置)
Epoch: 0, Steps: 70010, Batch: 9, Loss: 0.3402
Epoch: 0, Steps: 70020, Batch: 19, Loss: 0.3638

问题分析

通过代码审查,我们发现问题出在 LoRA 初始化的顺序上。原始代码是这样的:

def init_lora(self, lora_path: str, lora_config: dict):if lora_path: # 1. 先添加新的 adapter(这会初始化新的权重!)self.transformer.add_adapter(LoraConfig(**lora_config))# 2. 然后尝试加载预训练权重(为时已晚)self.flux_pipe = load_lora(self.flux_pipe, lora_path, alpha=4.0,  # 硬编码的 alpha 值use_adapter_method=False)

这里有两个关键问题:

  1. 初始化顺序错误:先调用 add_adapter 会创建新的 LoRA 层并随机初始化权重,这会覆盖掉后续加载的预训练权重。
  2. 硬编码的 alpha 值:使用固定的 alpha=4.0,而不是配置文件中指定的值。

解决方案

修改后的代码如下:

def init_lora(self, lora_path: str, lora_config: dict):assert lora_path or lora_configif lora_path: # 1. 首先将所有参数设置为不可训练for param in self.transformer.parameters():param.requires_grad = False# 2. 直接加载预训练的 LoRA 权重self.flux_pipe = load_lora(self.flux_pipe, lora_path, alpha=lora_config["lora_alpha"],  # 使用配置文件中的 alpha 值use_adapter_method=True,          # 使用 PEFT 的 adapter 机制adapter_name="default"            # 指定 adapter 名称)print(f"已加载LoRA权重: {lora_path},使用 alpha={lora_config['lora_alpha']}")# 3. 获取需要训练的层lora_layers = filter(lambda p: p.requires_grad, self.transformer.parameters())   else:# 如果是新训练,才添加新的 adapterself.transformer.add_adapter(LoraConfig(**lora_config))lora_layers = filter(lambda p: p.requires_grad, self.transformer.parameters())return list(lora_layers)

关键改进:

  1. 正确的加载顺序:直接使用 PEFT 的 adapter 机制加载预训练权重,而不是先创建新的 adapter。
  2. 动态的 alpha 值:使用配置文件中指定的 lora_alpha 值,确保与训练配置一致。
  3. 清晰的状态管理:只在新训练时才创建新的 adapter,避免意外重置权重。

技术要点

  1. PEFT 的 adapter 机制:PEFT 库提供了标准的 adapter 加载机制,能够正确处理权重的加载和初始化。

  2. LoRA 配置一致性:确保加载时使用的配置(如 alpha 值)与训练时保持一致,这对模型行为的一致性很重要。

  3. 权重加载顺序:在深度学习框架中,操作顺序往往会影响最终结果。要特别注意初始化和加载的顺序。

您说得对!让我补充一下这个重要的技术细节。我将在博客中添加一个新的章节:

深入 LoRA 加载机制

在我们的解决方案中,实际的 LoRA 加载是通过 load_lora_utils_ours.py 中的代码实现的:

# 方法1: 使用适配器方法加载
if use_adapter_method:if not adapter_name:raise ValueError("使用适配器方法加载时必须提供adapter_name参数")try:# 关键实现:优先使用 load_lora_weights 方法if hasattr(pipe, 'load_lora_weights'):pipe.load_lora_weights(lora_path, adapter_name=adapter_name)logging.info(f"使用 load_lora_weights 方法加载 LoRA 权重,适配器名称: {adapter_name}")# 备选:使用 load_lora 方法elif hasattr(pipe, 'load_lora'):pipe.load_lora(lora_path)logging.info("使用 load_lora 方法加载 LoRA 权重")else:logging.warning("模型没有内置的 LoRA 加载方法,将切换为手动加载方式")use_adapter_method = Falseexcept Exception as e:logging.warning(f"使用适配器方法加载失败: {str(e)},将切换为手动加载方式")use_adapter_method = False

这段代码展示了一个优雅的加载策略:

  1. 优先使用标准方法

    • 首选 load_lora_weights 方法,这是 PEFT 库推荐的标准方式
    • 该方法会正确处理 adapter 的创建和权重加载
  2. 优雅的降级机制

    • 如果 load_lora_weights 不可用,尝试使用 load_lora 方法
    • 如果两种方法都不可用,会降级到手动加载模式
  3. 错误处理

    • 完整的异常处理确保即使标准加载失败,也能通过手动方式继续加载
    • 详细的日志输出帮助追踪加载过程

结论

这个案例展示了在处理模型权重加载时的一个常见陷阱。正确的加载顺序和配置对于确保模型能够正确地继续训练至关重要。通过使用 PEFT 的标准机制并注意初始化顺序,我们可以避免训练状态被意外重置的问题。

这个经验也提醒我们,在深度学习工程实践中,有时看似简单的操作顺序调整可能会对训练结果产生重大影响。保持良好的代码组织和清晰的状态管理可以帮助我们避免类似的问题。

http://www.xdnf.cn/news/3203.html

相关文章:

  • 在宝塔面板中安装OpenJDK-17的三种方法
  • K8S - 从零构建 Docker 镜像与容器
  • OpenCV 图形API(73)图像与通道拼接函数-----执行 查找表操作图像处理函数LUT()
  • AdaBoost算法的原理及Python实现
  • Vue ui初始化项目并使用iview写一个菜单导航
  • BUUCTF——Fakebook 1
  • UE 材质 条纹循环发光
  • Android compileSdkVersion、minSdkVersion、targetSdkVersion的关系以及和Unity的关系
  • Qwen3本地化部署,准备工作:SGLang
  • K8S - 从单机到集群 - 核心对象与实战解析
  • 同时启动俩个tomcat压缩版
  • C# 在VS2022中开发常用设置
  • Python 爬取微店商品列表接口(item_search)的实战指南
  • 如何在Windows上实现MacOS中的open命令
  • 网工_ICMP协议
  • Linux-04-用户管理命令
  • Java List分页工具
  • 排序算法——选择排序
  • 微格式:为Web内容赋予语义的力量
  • 【Linux 网络】网络工具ifconfig和iproute/iproute2工具详解
  • 端到端观测分析:从前端负载均衡到后端服务
  • 进程、线程、进程间通信Unix Domain Sockets (UDS)
  • 《操作系统真象还原》第十一章——用户进程
  • Spring 框架中的常见注解讲解
  • Qt窗口关闭特效:自底而上逐渐消失
  • google colab设置python环境为python3.7
  • 提高程序灵活性和效率的利器:Natasha动态编译库【.Net】
  • 【学习笔记】Shell编程--Bash变量
  • HBuider中Uniapp去除顶部导航栏-小程序、H5、APP适用
  • 线上婚恋相亲小程序源码介绍