About loading deepspeed checkpoints with resume_from_checkpoint #4765

Open
@MitsuiChen14

Description

When I use the following script to resume training from a checkpoint, I get the error below:

MASTER_PORT=29999 \
swift sft \
 --train_type full \
 --dataset TinyChart_train_url_messages.jsonl \
            elem-v2-train.jsonl \
            tikuquery-rewriteqa-train.jsonl \
            wendang-rewriteqa-train.jsonl \
 --torch_dtype bfloat16 \
 --attn_impl flash_attn \
 --dataset_num_proc 32 \
 --num_train_epochs 1 \
 --max_length 4096 \
 --truncation_strategy 'delete' \
 --per_device_train_batch_size 1 \
 --learning_rate 1e-5 \
 --gradient_accumulation_steps 8 \
 --save_steps 500 \
 --logging_steps 5 \
 --output_dir ${PRIMUS_SAVE_CHECKPOINT_DIR} \
 --system 'You are a helpful assistant.' \
 --warmup_ratio 0.05 \
 --dataloader_num_workers 4 \
 --gradient_checkpointing False \
 --freeze_vit False \
 --freeze_aligner False \
 --eval_strategy no \
 --report_to tensorboard \
 --logging_dir ${PRIMUS_TENSORBOARD_LOG_DIR} \
 --deepspeed zero3_offload \
 --resume_from_checkpoint ${continue_path} 
[rank1]: Traceback (most recent call last):
[rank1]:   File "/root/code/MLLM_ChartQA_SFT_CWZ/swift/cli/sft.py", line 7, in <module>
[rank1]:     sft_main()
[rank1]:   File "/root/code/MLLM_ChartQA_SFT_CWZ/swift/llm/train/sft.py", line 273, in sft_main
[rank1]:     return SwiftSft(args).main()
[rank1]:   File "/root/code/MLLM_ChartQA_SFT_CWZ/swift/llm/base.py", line 49, in main
[rank1]:     result = self.run()
[rank1]:   File "/root/code/MLLM_ChartQA_SFT_CWZ/swift/llm/train/sft.py", line 129, in run
[rank1]:     return self.train(trainer)
[rank1]:   File "/root/code/MLLM_ChartQA_SFT_CWZ/swift/llm/train/sft.py", line 189, in train
[rank1]:     trainer.train(trainer.args.resume_from_checkpoint)
[rank1]:   File "/root/code/MLLM_ChartQA_SFT_CWZ/swift/trainers/mixin.py", line 369, in train
[rank1]:     res = super().train(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2245, in train
[rank1]:     return inner_training_loop(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2398, in _inner_training_loop
[rank1]:     deepspeed_load_checkpoint(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/transformers/integrations/deepspeed.py", line 489, in deepspeed_load_checkpoint
[rank1]:     load_path, _ = deepspeed_engine.load_checkpoint(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 2980, in load_checkpoint
[rank1]:     load_path, client_states = self._load_checkpoint(load_dir,
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 3063, in _load_checkpoint
[rank1]:     self.load_module_state_dict(checkpoint=checkpoint,
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 2840, in load_module_state_dict
[rank1]:     self.module.load_state_dict(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2581, in load_state_dict
[rank1]:     raise RuntimeError(
[rank1]: RuntimeError: Error(s) in loading state_dict for Qwen2_5_VLForConditionalGeneration:
[rank1]: 	Missing key(s) in state_dict: "visual.patch_embed.proj.weight",... ...
[rank1]: 	Unexpected key(s) in state_dict: "model.visual.patch_embed.proj.weight",... ...

It looks like, when loading the intermediate state saved by deepspeed, the checkpoint keys carry an extra "model." prefix, and I'm not sure where this problem is introduced.
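
Not a fix for the root cause, but if the only discrepancy really is the extra "model." prefix, a key remap along the following lines could patch the saved module state dict before resuming. This is a minimal sketch: the checkpoint file name and the "module" dict key are assumptions about the DeepSpeed checkpoint layout, which varies with the ZeRO stage and DeepSpeed version.

# Hedged workaround sketch: strip a leading "model." prefix from checkpoint keys
# so they match what the freshly built Qwen2_5_VLForConditionalGeneration expects.
import torch

def strip_model_prefix(state_dict, expected_keys):
    # e.g. "model.visual.patch_embed.proj.weight" -> "visual.patch_embed.proj.weight",
    # but only when the stripped name is actually expected by the live model.
    remapped = {}
    for key, value in state_dict.items():
        candidate = key[len("model."):] if key.startswith("model.") else key
        remapped[candidate if candidate in expected_keys else key] = value
    return remapped

# Hypothetical usage (the path and the "module" key are assumptions, not verified
# against the zero3_offload layout):
# ckpt_file = f"{continue_path}/global_step500/mp_rank_00_model_states.pt"
# ckpt = torch.load(ckpt_file, map_location="cpu")
# expected = set(model.state_dict().keys())  # keys of the freshly built model
# ckpt["module"] = strip_model_prefix(ckpt["module"], expected)
# torch.save(ckpt, ckpt_file)

Since a remap like this only papers over the symptom, it is probably worth first checking that the environment used to resume matches the one used when the checkpoint was saved.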
