Ali Mosavian
committed on
FIX: TRL trainer preprocessing step was running in one process (#1583)
* FIX: TRL trainer preprocessing step was running in one process
* FIX: Changed so that dataset_num_proc is sent to the CPO, KTO and ORPO trainer args, and directly to the trainer when using DPO
* FIX: Changed back to only support ORPO for now, since KTO is handled in another way
---------
Co-authored-by: Ali Mosavian <[email protected]>
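
For context: the TRL preference trainers preprocess their datasets with datasets.Dataset.map during construction, and when dataset_num_proc is left unset the map runs with num_proc=None, i.e. in a single process. A minimal sketch of that mechanism, assuming a toy dataset and a stand-in tokenize_row function (both hypothetical, not taken from the diff):

    from datasets import Dataset

    # Toy preference dataset (hypothetical; stands in for the real training data).
    ds = Dataset.from_dict(
        {
            "prompt": ["p1", "p2", "p3", "p4"],
            "chosen": ["c1", "c2", "c3", "c4"],
            "rejected": ["r1", "r2", "r3", "r4"],
        }
    )

    def tokenize_row(example):
        # Stand-in for the per-example preprocessing a TRL trainer applies.
        example["prompt_len"] = len(example["prompt"])
        return example

    # Before the fix: num_proc effectively defaulted to None, so this ran in one process.
    ds_single = ds.map(tokenize_row)

    # After the fix: cfg.dataset_processes is forwarded as dataset_num_proc,
    # so the same map can fan out across worker processes.
    ds_parallel = ds.map(tokenize_row, num_proc=2)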
src/axolotl/core/trainer_builder.py
CHANGED
@@ -1462,6 +1462,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             training_args_kwargs["eval_steps"] = self.cfg.eval_steps
         else:
             training_args_kwargs["evaluation_strategy"] = "no"
+
         if self.cfg.bf16 or self.cfg.bfloat16:
             training_args_kwargs["bf16"] = True
 
@@ -1520,6 +1521,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         training_args_cls = TrainingArguments
         if self.cfg.rl == "orpo":
             training_args_cls = ORPOConfig
+            training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
 
         training_args = training_args_cls(
             per_device_train_batch_size=self.cfg.micro_batch_size,
@@ -1564,6 +1566,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             dpo_trainer_kwargs["max_target_length"] = None
             dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
             dpo_trainer_kwargs["generate_during_eval"] = True
+            if self.cfg.rl == "dpo":
+                dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
         elif self.cfg.rl == "orpo":
             trainer_cls = AxolotlORPOTrainer
             trainer_cls_args = [self.model]