Ali Mosavian committed
Commit b9bb169 · unverified · 1 Parent(s): 601c08b

FIX: TRL trainer preprocessing step was running in one process (#1583)


* FIX: TRL trainer preprocessing step was running in one process

* FIX: Changed so that dataset_num_proc is passed via the CPO, KTO and ORPO trainer args, and directly to the trainer for DPO

* FIX: Changed back to only supporting ORPO for now, since KTO is handled differently

---------

Co-authored-by: Ali Mosavian <[email protected]>
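
Background: TRL's preference trainers tokenize and map the whole dataset up front, and without a `dataset_num_proc` value those `datasets` map calls fall back to a single process. This change forwards Axolotl's existing `dataset_processes` config setting to TRL so the mapping is parallelized. Below is a minimal sketch of where the forwarded value ends up, assuming a TRL release in which ORPOConfig exposes `dataset_num_proc` and DPOTrainer accepts it as a keyword argument (which the diff below relies on); the model, tokenizer and dataset objects are placeholders, not Axolotl code:

```python
from trl import ORPOConfig, ORPOTrainer

# ORPO: the value travels inside the config object (training_args_kwargs in the diff).
orpo_args = ORPOConfig(
    output_dir="./outputs",
    per_device_train_batch_size=2,
    dataset_num_proc=8,  # preprocessing map() now uses 8 worker processes
)
trainer = ORPOTrainer(
    model,                   # placeholder model
    args=orpo_args,
    train_dataset=train_ds,  # placeholder preference dataset
    tokenizer=tokenizer,     # placeholder tokenizer
)

# DPO: the value is passed straight to the trainer instead (dpo_trainer_kwargs in the diff):
# DPOTrainer(model, args=dpo_args, dataset_num_proc=8, ...)
```

For ORPO the value rides along inside the config object built from `training_args_kwargs`; for DPO it has to go to the trainer constructor, which is why the diff adds it to `dpo_trainer_kwargs` instead.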

src/axolotl/core/trainer_builder.py CHANGED
@@ -1462,6 +1462,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             training_args_kwargs["eval_steps"] = self.cfg.eval_steps
         else:
             training_args_kwargs["evaluation_strategy"] = "no"
+
         if self.cfg.bf16 or self.cfg.bfloat16:
             training_args_kwargs["bf16"] = True
 
@@ -1520,6 +1521,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         training_args_cls = TrainingArguments
         if self.cfg.rl == "orpo":
             training_args_cls = ORPOConfig
+            training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
 
         training_args = training_args_cls(
             per_device_train_batch_size=self.cfg.micro_batch_size,
@@ -1564,6 +1566,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             dpo_trainer_kwargs["max_target_length"] = None
             dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
             dpo_trainer_kwargs["generate_during_eval"] = True
+            if self.cfg.rl == "dpo":
+                dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
         elif self.cfg.rl == "orpo":
             trainer_cls = AxolotlORPOTrainer
             trainer_cls_args = [self.model]
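
The speedup itself comes from `datasets.Dataset.map(num_proc=...)`, which is what TRL ultimately calls while preprocessing the pairs. A small standalone illustration with a hypothetical `tokenize_row` function (not TRL's actual row mapper):

```python
from datasets import Dataset

def tokenize_row(example):
    # stand-in for TRL's per-row tokenization of prompt/chosen/rejected
    return {
        "n_chars": len(example["prompt"]) + len(example["chosen"]) + len(example["rejected"])
    }

ds = Dataset.from_dict({
    "prompt": ["Q1", "Q2"] * 1000,
    "chosen": ["good answer"] * 2000,
    "rejected": ["bad answer"] * 2000,
})

# num_proc=1 (the old behaviour) runs the map in the main process only;
# num_proc=8 shards the dataset across 8 worker processes.
ds_single = ds.map(tokenize_row, num_proc=1)
ds_multi = ds.map(tokenize_row, num_proc=8)
```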