Fix ORPO multi gpu (#1433)
* don't drop attention_mask for orpo
* handle multi-gpu cases better for orpo
* revert change to not drop the attention_mask from inputs for orpo
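The core of the change: instead of running the chosen and rejected halves of the batch through the model in separate forward passes, the trainer now pads both to a common length, concatenates them along the batch dimension, performs a single forward pass, and chunks the logits back into positive and negative halves (see `orpo_concatenate_inputs` and the `# Perform a single forward pass` block in the diff below). The snippet here is a minimal, self-contained sketch of that tensor flow, not axolotl code; `pad_to_len` and `toy_model` are illustrative stand-ins.

```python
import torch
import torch.nn.functional as F


def pad_to_len(tensor, length, pad_value):
    # Right-pad the last dimension to `length` (the role trl's pad_to_length
    # helper plays in the diff).
    if tensor.shape[-1] >= length:
        return tensor
    return F.pad(tensor, (0, length - tensor.shape[-1]), value=pad_value)


def toy_model(input_ids):
    # Stand-in for a causal LM forward: random "logits" of shape (batch, seq, vocab).
    return torch.randn(input_ids.shape[0], input_ids.shape[1], 16)


# One chosen and one rejected sequence, with different lengths.
chosen = torch.tensor([[5, 6, 7, 8]])
rejected = torch.tensor([[5, 6, 9]])

# Pad both to a common length, then stack them into a single batch.
max_len = max(chosen.shape[1], rejected.shape[1])
batch = torch.cat(
    [pad_to_len(chosen, max_len, 0), pad_to_len(rejected, max_len, 0)], dim=0
)

# Single forward pass over the concatenated batch, then split the logits back
# into the chosen and rejected halves.
logits = toy_model(batch)
logits_pos, logits_neg = logits.chunk(2)
print(batch.shape, logits_pos.shape, logits_neg.shape)
```

Every rank now executes the same single forward pass over a uniformly padded batch, which is presumably what makes the multi-GPU case better behaved.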
src/axolotl/core/trainer_builder.py
CHANGED
@@ -30,6 +30,7 @@ from transformers import (
 from transformers.trainer_utils import seed_worker
 from transformers.utils import is_sagemaker_mp_enabled
 from trl import DPOTrainer
+from trl.trainer.utils import pad_to_length
 
 from axolotl.loraplus import create_loraplus_optimizer
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
@@ -472,6 +473,58 @@ class AxolotlTrainer(Trainer):
             return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
         return super().compute_loss(model, inputs, return_outputs=return_outputs)
 
+    @staticmethod
+    def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
+        concatenated_batch = {}
+
+        max_length = max(
+            inputs["input_ids"].shape[1], inputs["rejected_input_ids"].shape[1]
+        )
+        # Concatenate positive and negative inputs
+        concatenated_batch["input_ids"] = pad_to_length(
+            inputs["input_ids"], max_length, pad_token
+        )
+        concatenated_batch["rejected_input_ids"] = pad_to_length(
+            inputs["rejected_input_ids"], max_length, pad_token
+        )
+        concatenated_batch["labels"] = pad_to_length(
+            inputs["labels"], max_length, label_pad_token
+        )
+        concatenated_batch["rejected_labels"] = pad_to_length(
+            inputs["rejected_labels"], max_length, label_pad_token
+        )
+        concatenated_batch["attention_mask"] = pad_to_length(
+            inputs["attention_mask"], max_length, 0
+        )
+        concatenated_batch["rejected_attention_mask"] = pad_to_length(
+            inputs["rejected_attention_mask"], max_length, 0
+        )
+        concatenated_batch["prompt_attention_mask"] = pad_to_length(
+            inputs["prompt_attention_mask"], max_length, 0
+        ).to(device=device)
+
+        input_ids = torch.cat(
+            [concatenated_batch["input_ids"], concatenated_batch["rejected_input_ids"]],
+            dim=0,
+        ).to(device=device)
+        attention_mask = torch.cat(
+            [
+                concatenated_batch["attention_mask"],
+                concatenated_batch["rejected_attention_mask"],
+            ],
+            dim=0,
+        ).to(device=device)
+        labels = torch.cat(
+            [concatenated_batch["labels"], concatenated_batch["rejected_labels"]], dim=0
+        ).to(device=device)
+
+        return {
+            "input_ids": input_ids,
+            "labels": labels,
+            "attention_mask": attention_mask,
+            "prompt_attention_mask": concatenated_batch["prompt_attention_mask"],
+        }
+
     def orpo_compute_custom_loss(self, logits, labels):
         logits = logits.contiguous()
         loss = 0.0
@@ -512,45 +565,46 @@ class AxolotlTrainer(Trainer):
             dim=2,
             index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
         ).squeeze(2)
-        return torch.mul(per_token_logps, mask.
-            dtype=torch.float64
-        ) / mask.sum(dim=1).to(dtype=torch.float64)
+        return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)
 
     def orpo_compute_loss(self, model, inputs, return_outputs=False):
+        concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
+            inputs,
+            label_pad_token=-100,
+            pad_token=self.tokenizer.pad_token_id,
+            device=self.accelerator.device,
         )
+
+        # Perform a single forward pass
+        outputs = model(
             **{
-                "input_ids":
-                "attention_mask":
-                "labels":
+                "input_ids": concat_inputs["input_ids"],
+                "attention_mask": concat_inputs["attention_mask"],
+                "labels": concat_inputs["labels"],
             },
             output_hidden_states=True,
         )
 
+        # Split the outputs for positive and negative examples
+        outputs_pos, outputs_neg = outputs.logits.chunk(2)
+
         # Calculate NLL loss
         pos_loss = self.orpo_compute_custom_loss(
-            logits=outputs_pos
+            logits=outputs_pos, labels=concat_inputs["input_ids"].chunk(2)[0]
         )
 
         # Calculate Log Probability
         pos_prob = self.orpo_compute_logps(
-            prompt_attention_mask=
-            chosen_inputs=
-            chosen_attention_mask=
-            logits=outputs_pos
+            prompt_attention_mask=concat_inputs["prompt_attention_mask"],
+            chosen_inputs=concat_inputs["input_ids"].chunk(2)[0],
+            chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[0],
+            logits=outputs_pos,
         )
         neg_prob = self.orpo_compute_logps(
-            prompt_attention_mask=
-            chosen_inputs=
-            chosen_attention_mask=
-            logits=outputs_neg
+            prompt_attention_mask=concat_inputs["prompt_attention_mask"],
+            chosen_inputs=concat_inputs["input_ids"].chunk(2)[1],
+            chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[1],
+            logits=outputs_neg,
         )
 
         # Calculate log odds
@@ -1247,6 +1301,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             train_dataset=self.train_dataset,
             eval_dataset=self.eval_dataset,
             args=training_args,
+            tokenizer=self.tokenizer,
             data_collator=self.build_collator(training_args, **data_collator_kwargs),
             eval_data_collator=self.build_collator(
                 training_args, is_eval=True, **data_collator_kwargs