当前位置: 首页 > news >正文

【踩坑记录】transformers 加载 checkpoint 继续训练

代码:
在这里插入图片描述
加上 transformers.utils.logging.set_verbosity_info() 可以显式详细的加载信息,比如下面:

在这里插入图片描述

保存后的文件格式:
在这里插入图片描述
可能报错:错误也比较好理解

[rank0]: Traceback (most recent call last):
[rank0]:   File "/mnt/code/grpo/my_grpo.py", line 76, in <module>
[rank0]:     main(training_args, model_args, custom_args)
[rank0]:   File "/mnt/code/grpo/my_grpo.py", line 67, in main
[rank0]:     trainer.train(resume_from_checkpoint=True)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 2245, in train
[rank0]:     return inner_training_loop(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 2534, in _inner_training_loop
[rank0]:     self._load_rng_state(resume_from_checkpoint)
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 3130, in _load_rng_state
[rank0]:     checkpoint_rng_state = torch.load(rng_file, weights_only=True)
[rank0]:                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/serialization.py", line 1359, in load
[rank0]:     raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
[rank0]: _pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint. 
[rank0]: 	(1) Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
[rank0]: 	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
[rank0]: 	WeightsUnpickler error: Unsupported global: GLOBAL numpy.core.multiarray._reconstruct was not an allowed global by default. Please use `torch.serialization.add_safe_globals([_reconstruct])` to allowlist this global if you trust this class/function.[rank0]: Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

解决方案:
/opt/conda/lib/python3.11/site-packages/transformers/trainer.py 中的 checkpoint_rng_state = torch.load(rng_file, weights_only=True) 改为 weights_only=True

加载 checkpoint 需要配置的是:

  • TrainingArguments 中的 output_dir 路径,也就是上图中 traing_args 中的 output_dir,路径应该为 /1b-new。设置 trainer.train(resume_from_checkpoint=True) 的话默认从 output_dir 中加载,或者可以传入一个 ckpt 路径。
  • GRPOTrainer 中的 model_name_or_path 路径设置为你想要继续训练的 checkpoint,比如 checkpoint-50,则设置为 /1b-new/checkpoint-50

存在一点问题:
设置了 use_reentrant: false 也没能保证继续训练时的 gard_norm 和之前一样,应该是没有完全恢复到之前的训练状态,但是最起码 learning rate、step 是恢复了,数据集的进度不知道有没有恢复。

http://www.xdnf.cn/news/482257.html

相关文章:

  • 微信小程序:封装表格组件并引用
  • 多模态大语言模型arxiv论文略读(七十九)
  • 每日算法刷题Day8 5.15:leetcode滑动窗口4道题,用时1h
  • COMSOL随机参数化表面流体流动模拟
  • linux 服务器安装jira-8.22.0和confluence-8.5.21
  • rinetd 实现通过访问主机访问虚拟机中的业务,调试虚拟机内的java进程
  • Qwen2.5-VL模型sft微调和使用vllm部署
  • TLS 1.3黑魔法:从协议破解到极致性能调优
  • 系统提示学习(System Prompt Learning)在医学编程中的初步分析与探索
  • 在Linux服务器上部署Jupyter Notebook并实现ssh无密码远程访问
  • 【Kubernetes】单Master集群部署(第二篇)
  • 15 C 语言字符类型详解:转义字符、格式化输出、字符类型本质、ASCII 码编程实战、最值宏汇总
  • 深度学习笔记23-LSTM实现火灾预测(Tensorflow)
  • Stratix 10 FPGA DDR4 选型
  • Visual Studio旧版直链
  • Elasticsearch 学习(一)如何在Linux 系统中下载、安装
  • 【简单模拟实现list】
  • 【PmHub后端篇】PmHub 中缓存与数据库一致性的实现方案及分析
  • c/c++的opencv的图像预处理讲解
  • 动态IP赋能业务增效:技术解构与实战应用指南
  • 1-10 目录树
  • 东方通2024年报分析:信创国产化龙头的蓬勃发展与未来可期
  • mysql的not exists走索引吗
  • uniapp-商城-60-后台 新增商品(属性的选中和页面显示)
  • MySQL——2、库的操作和表的操作
  • 割点与其例题
  • 管理工具导入CSV文件,中文数据乱码的解决办法。(APP)
  • 从类的外部访问静态成员:深入理解C#静态特性
  • C语言编程中的时间处理
  • 【学习笔记】机器学习(Machine Learning) | 第七章|神经网络(1)