[Pitfall Notes] Resuming training from a checkpoint with transformers
Code:
Adding transformers.utils.logging.set_verbosity_info() makes transformers print detailed loading information, for example:
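The original log output is not reproduced here. The sketch below assumes the call goes at the top of the training script, before the model and Trainer are created. Loading messages are then emitted at INFO level. The variable current_level exists only for illustration.

```python
from transformers.utils import logging as hf_logging

# Raise transformers' log level to INFO so that loading messages
# (for example, which checkpoint files are read) are printed.
hf_logging.set_verbosity_info()

# get_verbosity() returns the current level as a standard logging integer.
current_level = hf_logging.get_verbosity()
print("transformers log level:", current_level)
```

Put the call before any from_pretrained or trainer.train call. Otherwise the messages you want to see are emitted before the level is raised.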
Format of the saved files:
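The screenshot of the saved files is missing. The hedged sketch below lists what a checkpoint-N folder contains and checks for the resume-state files a Trainer checkpoint usually holds: trainer_state.json, optimizer.pt, scheduler.pt and rng_state*.pth. These file names are my assumption of the typical layout, not read from the original screenshot. describe_checkpoint is a hypothetical helper.

```python
import os
import tempfile

# Resume-state files a Trainer checkpoint usually contains besides the weights
# (assumed typical names; compare against your own checkpoint folder).
EXPECTED_RESUME_FILES = ("trainer_state.json", "optimizer.pt", "scheduler.pt")


def describe_checkpoint(ckpt_dir):
    """Hypothetical helper: list files and report missing resume-state files."""
    files = sorted(os.listdir(ckpt_dir))
    missing = [name for name in EXPECTED_RESUME_FILES if name not in files]
    # Single-process runs write rng_state.pth; multi-process runs write one
    # rng_state_<rank>.pth per process (hence rank0 in the traceback below).
    rng_files = [f for f in files if f.startswith("rng_state") and f.endswith(".pth")]
    return {"files": files, "missing": missing, "rng_files": rng_files}


if __name__ == "__main__":
    # Demo on a fake checkpoint folder filled with empty placeholder files.
    with tempfile.TemporaryDirectory() as root:
        ckpt = os.path.join(root, "checkpoint-50")
        os.makedirs(ckpt)
        for name in ("config.json", "trainer_state.json", "optimizer.pt",
                     "scheduler.pt", "rng_state_0.pth", "rng_state_1.pth"):
            open(os.path.join(ckpt, name), "w").close()
        print(describe_checkpoint(ckpt))
```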
A possible error (the message is fairly self-explanatory):
[rank0]: Traceback (most recent call last):
[rank0]: File "/mnt/code/grpo/my_grpo.py", line 76, in <module>
[rank0]: main(training_args, model_args, custom_args)
[rank0]: File "/mnt/code/grpo/my_grpo.py", line 67, in main
[rank0]: trainer.train(resume_from_checkpoint=True)
[rank0]: File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 2245, in train
[rank0]: return inner_training_loop(
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 2534, in _inner_training_loop
[rank0]: self._load_rng_state(resume_from_checkpoint)
[rank0]: File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 3130, in _load_rng_state
[rank0]: checkpoint_rng_state = torch.load(rng_file, weights_only=True)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/serialization.py", line 1359, in load
[rank0]: raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
[rank0]: _pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint.
[rank0]: (1) Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
[rank0]: (2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
[rank0]: WeightsUnpickler error: Unsupported global: GLOBAL numpy.core.multiarray._reconstruct was not an allowed global by default. Please use `torch.serialization.add_safe_globals([_reconstruct])` to allowlist this global if you trust this class/function.
[rank0]: Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.
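Solution:
In /opt/conda/lib/python3.11/site-packages/transformers/trainer.py, find
checkpoint_rng_state = torch.load(rng_file, weights_only=True)
and change it to weights_only=False. As the error message warns, do this only if you trust the source of the checkpoint, because it allows arbitrary code execution.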
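The failing object is the NumPy RNG state stored in the rng_state file. A NumPy array inside a pickle needs numpy globals that weights_only=True rejects. The minimal sketch below builds a state dict shaped like the one I believe Trainer saves. That layout, with keys python, numpy and cpu, is an assumption. The sketch then reloads the dict with weights_only=False, which is the same change as the patch above. save_rng_state and load_trusted_rng_state are hypothetical helpers, not transformers APIs.

```python
import os
import random
import tempfile

import numpy as np
import torch


def save_rng_state(path):
    """Save RNG states in a dict shaped like a Trainer rng_state file (assumed layout)."""
    state = {
        "python": random.getstate(),
        "numpy": np.random.get_state(),  # holds a NumPy array -> needs numpy globals
        "cpu": torch.random.get_rng_state(),
    }
    torch.save(state, path)


def load_trusted_rng_state(path):
    """Load an RNG state file from a trusted source.

    weights_only=False uses the full unpickler, which can execute arbitrary
    code, so only use it for checkpoints you produced yourself.
    """
    return torch.load(path, weights_only=False)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        path = os.path.join(d, "rng_state_0.pth")
        np.random.seed(0)
        save_rng_state(path)
        loaded = load_trusted_rng_state(path)
        np.random.set_state(loaded["numpy"])  # restore the NumPy RNG
        print(sorted(loaded.keys()))
```

Editing site-packages directly is undone whenever transformers is reinstalled or upgraded, so the patch may need to be reapplied after an environment rebuild.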
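What needs to be configured to resume from a checkpoint:
The output_dir in TrainingArguments (the output_dir of training_args in the code screenshot above) should be /1b-new. With trainer.train(resume_from_checkpoint=True), the Trainer loads the latest checkpoint found in output_dir. Alternatively, you can pass a specific checkpoint path. Set model_name_or_path for the GRPOTrainer to the checkpoint you want to continue from. For checkpoint-50, for example, set it to /1b-new/checkpoint-50.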
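The sketch below shows, in plain stdlib code, how resume_from_checkpoint=True picks its target: the checkpoint-N folder with the largest N inside output_dir. In the library this is done by transformers.trainer_utils.get_last_checkpoint. find_last_checkpoint below is a hypothetical re-implementation for illustration only. /1b-new is the path from the text.

```python
import os
import re
import tempfile

# Matches folder names such as "checkpoint-50".
_CKPT_RE = re.compile(r"^checkpoint-(\d+)$")


def find_last_checkpoint(output_dir):
    """Return the checkpoint-N subfolder with the largest N, or None."""
    candidates = []
    for name in os.listdir(output_dir):
        match = _CKPT_RE.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            candidates.append((int(match.group(1)), name))
    if not candidates:
        return None
    # Compare step numbers as integers so checkpoint-100 beats checkpoint-50.
    return os.path.join(output_dir, max(candidates)[1])


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as out:  # stands in for /1b-new
        for step in (10, 50, 100):
            os.makedirs(os.path.join(out, f"checkpoint-{step}"))
        print(find_last_checkpoint(out))
        # In a real script one could then call, for example:
        # trainer.train(resume_from_checkpoint=find_last_checkpoint("/1b-new"))
```

Passing an explicit path instead of True is useful when you want a checkpoint other than the latest one, such as /1b-new/checkpoint-50 while checkpoint-100 also exists.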
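One remaining issue:
Even with use_reentrant: false set, the grad_norm after resuming does not match the original run. The training state was probably not fully restored. At least the learning rate and the step count are restored. I have not verified whether the dataset progress is restored.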
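One way to locate where the resumed run diverges is to compare the log_history stored in trainer_state.json for the original run and the resumed run. The minimal sketch below assumes each training-step entry carries step, learning_rate and grad_norm keys, as recent transformers versions log them. load_log_history, compare_logs and the demo files are hypothetical.

```python
import json
import os
import tempfile


def load_log_history(trainer_state_path):
    """Map step -> (learning_rate, grad_norm) from a trainer_state.json file."""
    with open(trainer_state_path) as f:
        state = json.load(f)
    history = {}
    for entry in state.get("log_history", []):
        # Only training-step entries carry grad_norm; skip eval/summary entries.
        if "grad_norm" in entry:
            history[entry["step"]] = (entry.get("learning_rate"), entry["grad_norm"])
    return history


def compare_logs(original_path, resumed_path, rel_tol=1e-3):
    """Return (step, original, resumed) for shared steps whose grad_norm differs."""
    a, b = load_log_history(original_path), load_log_history(resumed_path)
    diffs = []
    for step in sorted(set(a) & set(b)):
        ga, gb = a[step][1], b[step][1]
        if abs(ga - gb) > rel_tol * max(abs(ga), abs(gb), 1e-12):
            diffs.append((step, ga, gb))
    return diffs


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        # Fake trainer_state.json files standing in for the two runs.
        orig = {"log_history": [{"step": 51, "learning_rate": 1e-6, "grad_norm": 0.8}]}
        resumed = {"log_history": [{"step": 51, "learning_rate": 1e-6, "grad_norm": 1.1}]}
        paths = []
        for name, content in (("orig.json", orig), ("resumed.json", resumed)):
            paths.append(os.path.join(d, name))
            with open(paths[-1], "w") as f:
                json.dump(content, f)
        print(compare_logs(*paths))
```

Matching learning_rate with differing grad_norm at the same step points to state outside the scheduler, such as data order or RNG, rather than to the step counter.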