Description
I hit the following error when trying to resume training from a checkpoint with the script below:
MASTER_PORT=29999 \
swift sft \
--train_type full \
--dataset TinyChart_train_url_messages.jsonl \
elem-v2-train.jsonl \
tikuquery-rewriteqa-train.jsonl \
wendang-rewriteqa-train.jsonl \
--torch_dtype bfloat16 \
--attn_impl flash_attn \
--dataset_num_proc 32 \
--num_train_epochs 1 \
--max_length 4096 \
--truncation_strategy 'delete' \
--per_device_train_batch_size 1 \
--learning_rate 1e-5 \
--gradient_accumulation_steps 8 \
--save_steps 500 \
--logging_steps 5 \
--output_dir ${PRIMUS_SAVE_CHECKPOINT_DIR} \
--system 'You are a helpful assistant.' \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--gradient_checkpointing False \
--freeze_vit False \
--freeze_aligner False \
--eval_strategy no \
--report_to tensorboard \
--logging_dir ${PRIMUS_TENSORBOARD_LOG_DIR} \
--deepspeed zero3_offload \
--resume_from_checkpoint ${continue_path}
[rank1]: Traceback (most recent call last):
[rank1]: File "/root/code/MLLM_ChartQA_SFT_CWZ/swift/cli/sft.py", line 7, in <module>
[rank1]: sft_main()
[rank1]: File "/root/code/MLLM_ChartQA_SFT_CWZ/swift/llm/train/sft.py", line 273, in sft_main
[rank1]: return SwiftSft(args).main()
[rank1]: File "/root/code/MLLM_ChartQA_SFT_CWZ/swift/llm/base.py", line 49, in main
[rank1]: result = self.run()
[rank1]: File "/root/code/MLLM_ChartQA_SFT_CWZ/swift/llm/train/sft.py", line 129, in run
[rank1]: return self.train(trainer)
[rank1]: File "/root/code/MLLM_ChartQA_SFT_CWZ/swift/llm/train/sft.py", line 189, in train
[rank1]: trainer.train(trainer.args.resume_from_checkpoint)
[rank1]: File "/root/code/MLLM_ChartQA_SFT_CWZ/swift/trainers/mixin.py", line 369, in train
[rank1]: res = super().train(*args, **kwargs)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2245, in train
[rank1]: return inner_training_loop(
[rank1]: File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2398, in _inner_training_loop
[rank1]: deepspeed_load_checkpoint(
[rank1]: File "/usr/local/lib/python3.10/dist-packages/transformers/integrations/deepspeed.py", line 489, in deepspeed_load_checkpoint
[rank1]: load_path, _ = deepspeed_engine.load_checkpoint(
[rank1]: File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 2980, in load_checkpoint
[rank1]: load_path, client_states = self._load_checkpoint(load_dir,
[rank1]: File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 3063, in _load_checkpoint
[rank1]: self.load_module_state_dict(checkpoint=checkpoint,
[rank1]: File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 2840, in load_module_state_dict
[rank1]: self.module.load_state_dict(
[rank1]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2581, in load_state_dict
[rank1]: raise RuntimeError(
[rank1]: RuntimeError: Error(s) in loading state_dict for Qwen2_5_VLForConditionalGeneration:
[rank1]: Missing key(s) in state_dict: "visual.patch_embed.proj.weight",... ...
[rank1]: Unexpected key(s) in state_dict: "model.visual.patch_embed.proj.weight",... ...
It looks like, when loading the intermediate state saved by DeepSpeed, the checkpoint keys carry an extra `model.` prefix relative to what the model now expects. I'm not sure where this mismatch is introduced.
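For what it's worth, here is a minimal sketch of the kind of key remapping that would reconcile the two layouts reported in the error. The helper name is my own invention, and this is only an illustration of the prefix mismatch, not a claim about where ms-swift or DeepSpeed should apply such a fix:

```python
def strip_model_prefix(state_dict, prefix="model."):
    """Return a copy of state_dict with `prefix` stripped from the
    start of each key, leaving other keys untouched."""
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

# Dummy keys mirroring the error message: the checkpoint has
# "model.visual...." while the model expects "visual....".
ckpt = {"model.visual.patch_embed.proj.weight": None}
remapped = strip_model_prefix(ckpt)
assert "visual.patch_embed.proj.weight" in remapped
```

Applying a remap like this by hand would only address the module weights; the DeepSpeed ZeRO-3 optimizer shards would presumably still need keys in a consistent layout, so a fix inside the loading path seems preferable.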