Skip to content

Commit 6b320d2

Browse files
committed
update
1 parent bebd903 commit 6b320d2

File tree

3 files changed

+17
-20
lines changed

3 files changed

+17
-20
lines changed

swift/llm/train/sft.py

Lines changed: 13 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -226,21 +226,20 @@ def _encode_dataset(self, train_dataset, val_dataset):
226226
template = self.template
227227
args = self.args
228228
is_grpo = hasattr(args, 'rlhf_type') and args.rlhf_type == 'grpo'
229-
if args.lazy_tokenize:
230-
train_dataset = LazyLLMDataset(
231-
train_dataset, template.encode, strict=args.strict, random_state=args.data_seed)
232-
if val_dataset is not None and not args.predict_with_generate:
233-
val_dataset = LazyLLMDataset(
234-
val_dataset, template.encode, strict=args.strict, random_state=args.data_seed)
235-
elif is_grpo:
236-
pass
237-
else:
238-
preprocessor_cls = PackingPreprocessor if args.packing else EncodePreprocessor
239-
preprocessor = preprocessor_cls(template=template)
240-
train_dataset = preprocessor(train_dataset, num_proc=args.dataset_num_proc, strict=args.strict)
241-
if val_dataset is not None and not args.predict_with_generate:
242-
val_dataset = preprocessor(val_dataset, num_proc=args.dataset_num_proc, strict=args.strict)
243229
if not is_grpo:
230+
if args.lazy_tokenize:
231+
train_dataset = LazyLLMDataset(
232+
train_dataset, template.encode, strict=args.strict, random_state=args.data_seed)
233+
if val_dataset is not None and not args.predict_with_generate:
234+
val_dataset = LazyLLMDataset(
235+
val_dataset, template.encode, strict=args.strict, random_state=args.data_seed)
236+
else:
237+
preprocessor_cls = PackingPreprocessor if args.packing else EncodePreprocessor
238+
preprocessor = preprocessor_cls(template=template)
239+
train_dataset = preprocessor(train_dataset, num_proc=args.dataset_num_proc, strict=args.strict)
240+
if val_dataset is not None and not args.predict_with_generate:
241+
val_dataset = preprocessor(val_dataset, num_proc=args.dataset_num_proc, strict=args.strict)
242+
244243
inputs = train_dataset[0] if hasattr(train_dataset, '__len__') else next(iter(train_dataset))
245244
template.print_inputs(inputs, tokenizer_kwargs=inputs.pop('tokenizer_kwargs', None) or {})
246245
if isinstance(train_dataset, HfDataset):

swift/trainers/rlhf_arguments.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
1-
from dataclasses import dataclass, field
2-
from typing import List
1+
from dataclasses import dataclass
32

43
from trl import CPOConfig as HfCPOConfig
54
from trl import DPOConfig as HfDPOConfig

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,17 +1,16 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
22
# Part of the implementation is borrowed from huggingface/trl.
33
from collections import defaultdict
4-
from typing import Any, Callable, Dict, List, Optional, Union
4+
from typing import Any, Dict, List, Optional, Union
55
from unittest.mock import patch
66

77
import torch
88
import torch.nn as nn
99
from accelerate.utils import broadcast_object_list, gather, gather_object
1010
from accelerate.utils.other import is_compiled_module
11-
from transformers import GenerationConfig, PreTrainedModel
11+
from transformers import PreTrainedModel
1212
from trl import GRPOTrainer as HFGRPOTrainer
1313
from trl.models import unwrap_model_for_generation
14-
from trl.trainer.utils import pad
1514

1615
from swift.llm import InferRequest, RequestConfig, to_device
1716
from swift.plugin.orm import orms
@@ -102,7 +101,7 @@ def __init__(self,
102101
logger.warning(
103102
f'The requested device {vllm_device} is also used for training. This may lead to unexpected '
104103
'behavior. It is recommended to use a dedicated device for vLLM.')
105-
from swift.llm import VllmEngine # , PtEngine
104+
from swift.llm import VllmEngine
106105
world_size_patch = patch('torch.distributed.get_world_size', return_value=1)
107106
profiling_patch = patch(
108107
'vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling', return_value=None)

0 commit comments

Comments
 (0)