winglian commited on
Commit
34ba634
·
unverified ·
1 Parent(s): 4e69aa4

Fix ORPO multi gpu (#1433)

Browse files

* don't drop attention_mask for orpo

* handle multi-gpu cases better for orpo

* revert change to not drop the attention_mask from inputs for orpo

Files changed (1) hide show
  1. src/axolotl/core/trainer_builder.py +78 -23
src/axolotl/core/trainer_builder.py CHANGED
@@ -30,6 +30,7 @@ from transformers import (
30
  from transformers.trainer_utils import seed_worker
31
  from transformers.utils import is_sagemaker_mp_enabled
32
  from trl import DPOTrainer
 
33
 
34
  from axolotl.loraplus import create_loraplus_optimizer
35
  from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
@@ -472,6 +473,58 @@ class AxolotlTrainer(Trainer):
472
  return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
473
  return super().compute_loss(model, inputs, return_outputs=return_outputs)
474
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
475
  def orpo_compute_custom_loss(self, logits, labels):
476
  logits = logits.contiguous()
477
  loss = 0.0
@@ -512,45 +565,46 @@ class AxolotlTrainer(Trainer):
512
  dim=2,
513
  index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
514
  ).squeeze(2)
515
- return torch.mul(per_token_logps, mask.to(dtype=torch.bfloat16)).sum(dim=1).to(
516
- dtype=torch.float64
517
- ) / mask.sum(dim=1).to(dtype=torch.float64)
518
 
519
  def orpo_compute_loss(self, model, inputs, return_outputs=False):
520
- outputs_neg = model(
521
- **{
522
- "input_ids": inputs["rejected_input_ids"],
523
- "attention_mask": inputs["rejected_attention_mask"],
524
- "labels": inputs["rejected_labels"],
525
- },
526
- output_hidden_states=True,
527
  )
528
- outputs_pos = model(
 
 
529
  **{
530
- "input_ids": inputs["input_ids"],
531
- "attention_mask": inputs["attention_mask"],
532
- "labels": inputs["labels"],
533
  },
534
  output_hidden_states=True,
535
  )
536
 
 
 
 
537
  # Calculate NLL loss
538
  pos_loss = self.orpo_compute_custom_loss(
539
- logits=outputs_pos.logits, labels=inputs["input_ids"]
540
  )
541
 
542
  # Calculate Log Probability
543
  pos_prob = self.orpo_compute_logps(
544
- prompt_attention_mask=inputs["prompt_attention_mask"],
545
- chosen_inputs=inputs["input_ids"],
546
- chosen_attention_mask=inputs["attention_mask"],
547
- logits=outputs_pos.logits,
548
  )
549
  neg_prob = self.orpo_compute_logps(
550
- prompt_attention_mask=inputs["prompt_attention_mask"],
551
- chosen_inputs=inputs["rejected_input_ids"],
552
- chosen_attention_mask=inputs["rejected_attention_mask"],
553
- logits=outputs_neg.logits,
554
  )
555
 
556
  # Calculate log odds
@@ -1247,6 +1301,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
1247
  train_dataset=self.train_dataset,
1248
  eval_dataset=self.eval_dataset,
1249
  args=training_args,
 
1250
  data_collator=self.build_collator(training_args, **data_collator_kwargs),
1251
  eval_data_collator=self.build_collator(
1252
  training_args, is_eval=True, **data_collator_kwargs
 
30
  from transformers.trainer_utils import seed_worker
31
  from transformers.utils import is_sagemaker_mp_enabled
32
  from trl import DPOTrainer
33
+ from trl.trainer.utils import pad_to_length
34
 
35
  from axolotl.loraplus import create_loraplus_optimizer
36
  from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 
473
  return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
474
  return super().compute_loss(model, inputs, return_outputs=return_outputs)
475
 
476
+ @staticmethod
477
+ def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
478
+ concatenated_batch = {}
479
+
480
+ max_length = max(
481
+ inputs["input_ids"].shape[1], inputs["rejected_input_ids"].shape[1]
482
+ )
483
+ # Concatenate positive and negative inputs
484
+ concatenated_batch["input_ids"] = pad_to_length(
485
+ inputs["input_ids"], max_length, pad_token
486
+ )
487
+ concatenated_batch["rejected_input_ids"] = pad_to_length(
488
+ inputs["rejected_input_ids"], max_length, pad_token
489
+ )
490
+ concatenated_batch["labels"] = pad_to_length(
491
+ inputs["labels"], max_length, label_pad_token
492
+ )
493
+ concatenated_batch["rejected_labels"] = pad_to_length(
494
+ inputs["rejected_labels"], max_length, label_pad_token
495
+ )
496
+ concatenated_batch["attention_mask"] = pad_to_length(
497
+ inputs["attention_mask"], max_length, 0
498
+ )
499
+ concatenated_batch["rejected_attention_mask"] = pad_to_length(
500
+ inputs["rejected_attention_mask"], max_length, 0
501
+ )
502
+ concatenated_batch["prompt_attention_mask"] = pad_to_length(
503
+ inputs["prompt_attention_mask"], max_length, 0
504
+ ).to(device=device)
505
+
506
+ input_ids = torch.cat(
507
+ [concatenated_batch["input_ids"], concatenated_batch["rejected_input_ids"]],
508
+ dim=0,
509
+ ).to(device=device)
510
+ attention_mask = torch.cat(
511
+ [
512
+ concatenated_batch["attention_mask"],
513
+ concatenated_batch["rejected_attention_mask"],
514
+ ],
515
+ dim=0,
516
+ ).to(device=device)
517
+ labels = torch.cat(
518
+ [concatenated_batch["labels"], concatenated_batch["rejected_labels"]], dim=0
519
+ ).to(device=device)
520
+
521
+ return {
522
+ "input_ids": input_ids,
523
+ "labels": labels,
524
+ "attention_mask": attention_mask,
525
+ "prompt_attention_mask": concatenated_batch["prompt_attention_mask"],
526
+ }
527
+
528
  def orpo_compute_custom_loss(self, logits, labels):
529
  logits = logits.contiguous()
530
  loss = 0.0
 
565
  dim=2,
566
  index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
567
  ).squeeze(2)
568
+ return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)
 
 
569
 
570
  def orpo_compute_loss(self, model, inputs, return_outputs=False):
571
+ concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
572
+ inputs,
573
+ label_pad_token=-100,
574
+ pad_token=self.tokenizer.pad_token_id,
575
+ device=self.accelerator.device,
 
 
576
  )
577
+
578
+ # Perform a single forward pass
579
+ outputs = model(
580
  **{
581
+ "input_ids": concat_inputs["input_ids"],
582
+ "attention_mask": concat_inputs["attention_mask"],
583
+ "labels": concat_inputs["labels"],
584
  },
585
  output_hidden_states=True,
586
  )
587
 
588
+ # Split the outputs for positive and negative examples
589
+ outputs_pos, outputs_neg = outputs.logits.chunk(2)
590
+
591
  # Calculate NLL loss
592
  pos_loss = self.orpo_compute_custom_loss(
593
+ logits=outputs_pos, labels=concat_inputs["input_ids"].chunk(2)[0]
594
  )
595
 
596
  # Calculate Log Probability
597
  pos_prob = self.orpo_compute_logps(
598
+ prompt_attention_mask=concat_inputs["prompt_attention_mask"],
599
+ chosen_inputs=concat_inputs["input_ids"].chunk(2)[0],
600
+ chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[0],
601
+ logits=outputs_pos,
602
  )
603
  neg_prob = self.orpo_compute_logps(
604
+ prompt_attention_mask=concat_inputs["prompt_attention_mask"],
605
+ chosen_inputs=concat_inputs["input_ids"].chunk(2)[1],
606
+ chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[1],
607
+ logits=outputs_neg,
608
  )
609
 
610
  # Calculate log odds
 
1301
  train_dataset=self.train_dataset,
1302
  eval_dataset=self.eval_dataset,
1303
  args=training_args,
1304
+ tokenizer=self.tokenizer,
1305
  data_collator=self.build_collator(training_args, **data_collator_kwargs),
1306
  eval_data_collator=self.build_collator(
1307
  training_args, is_eval=True, **data_collator_kwargs