Commit 6b320d2: update

Jintao-Huang committed Feb 9, 2025
1 parent bebd903 commit 6b320d2
Showing 3 changed files with 17 additions and 20 deletions.
27 changes: 13 additions & 14 deletions swift/llm/train/sft.py
@@ -226,21 +226,20 @@ def _encode_dataset(self, train_dataset, val_dataset):
         template = self.template
         args = self.args
         is_grpo = hasattr(args, 'rlhf_type') and args.rlhf_type == 'grpo'
-        if args.lazy_tokenize:
-            train_dataset = LazyLLMDataset(
-                train_dataset, template.encode, strict=args.strict, random_state=args.data_seed)
-            if val_dataset is not None and not args.predict_with_generate:
-                val_dataset = LazyLLMDataset(
-                    val_dataset, template.encode, strict=args.strict, random_state=args.data_seed)
-        elif is_grpo:
-            pass
-        else:
-            preprocessor_cls = PackingPreprocessor if args.packing else EncodePreprocessor
-            preprocessor = preprocessor_cls(template=template)
-            train_dataset = preprocessor(train_dataset, num_proc=args.dataset_num_proc, strict=args.strict)
-            if val_dataset is not None and not args.predict_with_generate:
-                val_dataset = preprocessor(val_dataset, num_proc=args.dataset_num_proc, strict=args.strict)
+        if not is_grpo:
+            if args.lazy_tokenize:
+                train_dataset = LazyLLMDataset(
+                    train_dataset, template.encode, strict=args.strict, random_state=args.data_seed)
+                if val_dataset is not None and not args.predict_with_generate:
+                    val_dataset = LazyLLMDataset(
+                        val_dataset, template.encode, strict=args.strict, random_state=args.data_seed)
+            else:
+                preprocessor_cls = PackingPreprocessor if args.packing else EncodePreprocessor
+                preprocessor = preprocessor_cls(template=template)
+                train_dataset = preprocessor(train_dataset, num_proc=args.dataset_num_proc, strict=args.strict)
+                if val_dataset is not None and not args.predict_with_generate:
+                    val_dataset = preprocessor(val_dataset, num_proc=args.dataset_num_proc, strict=args.strict)
 
         inputs = train_dataset[0] if hasattr(train_dataset, '__len__') else next(iter(train_dataset))
         template.print_inputs(inputs, tokenizer_kwargs=inputs.pop('tokenizer_kwargs', None) or {})
         if isinstance(train_dataset, HfDataset):
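Note on the sft.py change: the three-way `if args.lazy_tokenize / elif is_grpo / else` chain is collapsed into a single `if not is_grpo:` guard, so GRPO runs skip encoding entirely while the lazy/eager choice is unchanged for every other mode. For readers unfamiliar with the lazy path, here is a minimal sketch of the idea behind LazyLLMDataset (a hypothetical simplification, not swift's actual implementation):

    # Minimal sketch of lazy tokenization (hypothetical, illustration only).
    from torch.utils.data import Dataset

    class LazyEncodeDataset(Dataset):
        """Defer encoding until a sample is actually requested."""

        def __init__(self, dataset, encode_fn, strict=False):
            self.dataset = dataset
            self.encode_fn = encode_fn
            self.strict = strict  # strict=False: skip samples that fail to encode

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, idx):
            try:
                return self.encode_fn(self.dataset[idx])
            except Exception:
                if self.strict:
                    raise
                # Fall back to a neighboring sample; the real class takes a
                # random_state and resamples instead.
                return self[(idx + 1) % len(self)]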
3 changes: 1 addition & 2 deletions swift/trainers/rlhf_arguments.py
@@ -1,5 +1,4 @@
-from dataclasses import dataclass, field
-from typing import List
+from dataclasses import dataclass
 
 from trl import CPOConfig as HfCPOConfig
 from trl import DPOConfig as HfDPOConfig
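The `field` and `List` imports dropped here were presumably left unused by an earlier refactor; what remains of the module is the pattern of wrapping each trl config in a swift dataclass. A hedged sketch of that pattern (the class body is assumed, not copied from the repo):

    from dataclasses import dataclass

    from trl import CPOConfig as HfCPOConfig

    @dataclass
    class CPOConfig(HfCPOConfig):
        # swift-specific options would be declared here as dataclass fields
        pass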
7 changes: 3 additions & 4 deletions swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -1,17 +1,16 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 # Part of the implementation is borrowed from huggingface/trl.
 from collections import defaultdict
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 from unittest.mock import patch
 
 import torch
 import torch.nn as nn
 from accelerate.utils import broadcast_object_list, gather, gather_object
 from accelerate.utils.other import is_compiled_module
-from transformers import GenerationConfig, PreTrainedModel
+from transformers import PreTrainedModel
 from trl import GRPOTrainer as HFGRPOTrainer
 from trl.models import unwrap_model_for_generation
 from trl.trainer.utils import pad
 
 from swift.llm import InferRequest, RequestConfig, to_device
 from swift.plugin.orm import orms
@@ -102,7 +101,7 @@ def __init__(self,
             logger.warning(
                 f'The requested device {vllm_device} is also used for training. This may lead to unexpected '
                 'behavior. It is recommended to use a dedicated device for vLLM.')
-        from swift.llm import VllmEngine  # , PtEngine
+        from swift.llm import VllmEngine
         world_size_patch = patch('torch.distributed.get_world_size', return_value=1)
         profiling_patch = patch(
             'vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling', return_value=None)
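The `world_size_patch` and `profiling_patch` shown above let the trainer construct a vLLM engine inside an already-distributed training process: vLLM briefly sees a world size of 1 and skips its memory-profiling assertion. A standalone demonstration of the mocking technique (the patch target is copied from the diff; everything else is illustrative):

    import torch.distributed as dist
    from unittest.mock import patch

    # Make distributed-aware code see a single rank, exactly as the
    # trainer does around vLLM engine construction.
    world_size_patch = patch('torch.distributed.get_world_size', return_value=1)

    with world_size_patch:
        # Normally this call raises unless init_process_group() has run;
        # under the patch it simply returns the mocked value.
        print(dist.get_world_size())  # -> 1

    # The trainer pairs this with a second patch that no-ops
    # vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling
    # before creating its VllmEngine (see the diff above).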
