@@ -226,21 +226,20 @@ def _encode_dataset(self, train_dataset, val_dataset):
226
226
template = self .template
227
227
args = self .args
228
228
is_grpo = hasattr (args , 'rlhf_type' ) and args .rlhf_type == 'grpo'
229
- if args .lazy_tokenize :
230
- train_dataset = LazyLLMDataset (
231
- train_dataset , template .encode , strict = args .strict , random_state = args .data_seed )
232
- if val_dataset is not None and not args .predict_with_generate :
233
- val_dataset = LazyLLMDataset (
234
- val_dataset , template .encode , strict = args .strict , random_state = args .data_seed )
235
- elif is_grpo :
236
- pass
237
- else :
238
- preprocessor_cls = PackingPreprocessor if args .packing else EncodePreprocessor
239
- preprocessor = preprocessor_cls (template = template )
240
- train_dataset = preprocessor (train_dataset , num_proc = args .dataset_num_proc , strict = args .strict )
241
- if val_dataset is not None and not args .predict_with_generate :
242
- val_dataset = preprocessor (val_dataset , num_proc = args .dataset_num_proc , strict = args .strict )
243
229
if not is_grpo :
230
+ if args .lazy_tokenize :
231
+ train_dataset = LazyLLMDataset (
232
+ train_dataset , template .encode , strict = args .strict , random_state = args .data_seed )
233
+ if val_dataset is not None and not args .predict_with_generate :
234
+ val_dataset = LazyLLMDataset (
235
+ val_dataset , template .encode , strict = args .strict , random_state = args .data_seed )
236
+ else :
237
+ preprocessor_cls = PackingPreprocessor if args .packing else EncodePreprocessor
238
+ preprocessor = preprocessor_cls (template = template )
239
+ train_dataset = preprocessor (train_dataset , num_proc = args .dataset_num_proc , strict = args .strict )
240
+ if val_dataset is not None and not args .predict_with_generate :
241
+ val_dataset = preprocessor (val_dataset , num_proc = args .dataset_num_proc , strict = args .strict )
242
+
244
243
inputs = train_dataset [0 ] if hasattr (train_dataset , '__len__' ) else next (iter (train_dataset ))
245
244
template .print_inputs (inputs , tokenizer_kwargs = inputs .pop ('tokenizer_kwargs' , None ) or {})
246
245
if isinstance (train_dataset , HfDataset ):
0 commit comments