Text Generation
Transformers
Safetensors
lola_v1
custom_code
neo-nlp-dev committed
Commit 06859d4
1 Parent(s): c2502f1

updating lola modeling class with auxiliary loss

Files changed (1)
  1. modeling_lola_gpt2.py +163 -217
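In brief: each MoE layer now computes a Switch/Mixtral-style load-balancing auxiliary loss (the fraction of tokens routed to each expert multiplied by the router's mean probability for that expert, summed over experts and scaled by the number of experts), the base model sums it across layers, and the language-modeling head adds it to the training loss. The sketch below illustrates only that quantity outside the model; the function name and shapes are illustrative and not part of the commit.

import torch
import torch.nn.functional as F

def load_balancing_aux_loss(router_logits: torch.Tensor, top_k: int, num_experts: int) -> torch.Tensor:
    # router_logits: [num_tokens, num_experts], as produced by an MoE gate
    probs = F.softmax(router_logits, dim=-1)
    _, selected_experts = torch.topk(probs, top_k, dim=-1)             # [num_tokens, top_k]
    expert_mask = F.one_hot(selected_experts, num_experts).sum(dim=1)  # [num_tokens, num_experts]
    token_fraction = expert_mask.float().sum(dim=0) / expert_mask.float().sum()
    mean_router_prob = probs.mean(dim=0)
    return torch.sum(token_fraction * mean_router_prob) * num_experts

# A uniform router gives a loss of about 1.0; skewed routing pushes it higher.
print(load_balancing_aux_loss(torch.randn(8, 4), top_k=1, num_experts=4))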
modeling_lola_gpt2.py CHANGED
@@ -7,6 +7,11 @@
import warnings
from typing import Optional, Tuple, Union

+## Uncomment the below three and comment the other import for model conversion
+#import sys
+# sys.path.append(".")
+# from configuration_lola_gpt2 import LOLAConfig
+
from .configuration_lola_gpt2 import LOLAConfig
import torch
import torch.utils.checkpoint
@@ -16,6 +21,7 @@ from torch.nn import CrossEntropyLoss

from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
+    MoeCausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
    QuestionAnsweringModelOutput
)
@@ -27,11 +33,68 @@ from transformers.utils import (
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map

from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP, GPT2Block, GPT2PreTrainedModel
-from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2DoubleHeadsModel, GPT2ForSequenceClassification, GPT2ForTokenClassification
+from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+from transformers.modeling_outputs import ModelOutput
+import transformers
+import importlib.util


logger = logging.get_logger(__name__)

+expert_analysis_callback = lambda _: None
+
+class LOLADependencyChecker:
+    def __init__(self):
+        self.expected_versions = {
+            "transformers": "4.38.2"
+        }
+        self.check_dependencies()
+
+    def check_dependencies(self):
+        # Check transformers version
+        self._check_version("transformers", transformers.__version__)
+
+    def _check_version(self, package_name, installed_version):
+        expected_version = self.expected_versions.get(package_name)
+        if installed_version != expected_version:
+            warnings.warn(
+                f"Warning: The installed {package_name} version ({installed_version}) "
+                f"differs from the expected version ({expected_version}). "
+                "This may lead to unexpected behavior.",
+                category=UserWarning
+            )
+
+@dataclass
+class MoeModelOutputWithPast(ModelOutput):
+    """
+    Base class for model's outputs with potential hidden states and attentions, and includes auxiliary loss.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed):
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+        router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed):
+            Router logits computed by MoE routers, used to compute the auxiliary loss for Mixture of Experts models.
+        aux_loss (`torch.FloatTensor`, *optional*):
+            The auxiliary loss computed from the MoE layers, used to encourage balanced expert utilization.
+    """
+
+    last_hidden_state: torch.FloatTensor = None
+    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor, ...]]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+    router_logits: Optional[Tuple[torch.FloatTensor, ...]] = None
+    aux_loss: Optional[torch.FloatTensor] = None
+
# LOLA
class LOLAModel(GPT2PreTrainedModel):

@@ -39,7 +102,9 @@ class LOLAModel(GPT2PreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
-
+        # Checking dependencies version
+        LOLADependencyChecker()
+
        self.embed_dim = config.hidden_size

        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
@@ -63,9 +128,9 @@
    def parallelize(self, device_map=None):
        # Check validity of device_map
        warnings.warn(
-            "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
-            " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
-            " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
+            "GPT2Model.parallelize is deprecated and will be removed in v5 of Transformers, you should load your"
+            " model with device_map='balanced' in the call to from_pretrained. You can also provide your own"
+            " device_map but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
            " ...}",
            FutureWarning,
        )
@@ -89,7 +154,7 @@ class LOLAModel(GPT2PreTrainedModel):

    def deparallelize(self):
        warnings.warn(
-            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
+            "Like parallelize, deparallelize is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        self.model_parallel = False
@@ -219,7 +284,7 @@ class LOLAModel(GPT2PreTrainedModel):
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                    "use_cache=True is incompatible with gradient checkpointing. Setting use_cache=False..."
                )
                use_cache = False

@@ -227,6 +292,7 @@ class LOLAModel(GPT2PreTrainedModel):
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None
+        aux_losses = []
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            # Model parallel
            if self.model_parallel:
@@ -269,11 +335,14 @@ class LOLAModel(GPT2PreTrainedModel):
            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)
+
+            if isinstance(block, LOLABlock):
+                # Collect auxiliary loss
+                aux_loss = outputs[-1]
+                aux_losses.append(aux_loss)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
-                if self.config.add_cross_attention:
-                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
@@ -284,23 +353,27 @@ class LOLAModel(GPT2PreTrainedModel):
        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(output_shape)
+        # Aggregate auxiliary losses
+        if aux_losses:
+            total_aux_loss = torch.stack(aux_losses).sum()
+        else:
+            total_aux_loss = None
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
-
        if not return_dict:
-            return tuple(
-                v
-                for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
-                if v is not None
-            )
+            output = (hidden_states, presents, all_hidden_states, all_self_attentions)
+            if total_aux_loss is not None:
+                output += (total_aux_loss,)
+            return tuple(v for v in output if v is not None)

-        return BaseModelOutputWithPastAndCrossAttentions(
+        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
-            cross_attentions=all_cross_attentions,
+            router_logits=None, # Include if router_logits are needed
+            aux_loss=total_aux_loss,
        )

class LOLABlock(nn.Module):
@@ -312,7 +385,6 @@ class LOLABlock(nn.Module):
        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = GPT2Attention(config, layer_idx=layer_idx)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-
        self.moe = LOLAMOE(
            hidden_size,
            inner_dim,
@@ -336,7 +408,7 @@ class LOLABlock(nn.Module):
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
-    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(
@@ -347,45 +419,21 @@ class LOLABlock(nn.Module):
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
-        attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
+        attn_output = attn_outputs[0]
        outputs = attn_outputs[1:]
-        # residual connection
        hidden_states = attn_output + residual

-        if encoder_hidden_states is not None:
-            # add one self-attention block for cross-attention
-            if not hasattr(self, "crossattention"):
-                raise ValueError(
-                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
-                    "cross-attention layers by setting `config.add_cross_attention=True`"
-                )
-            residual = hidden_states
-            hidden_states = self.ln_cross_attn(hidden_states)
-            cross_attn_outputs = self.crossattention(
-                hidden_states,
-                attention_mask=attention_mask,
-                head_mask=head_mask,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                output_attentions=output_attentions,
-            )
-            attn_output = cross_attn_outputs[0]
-            # residual connection
-            hidden_states = residual + attn_output
-            outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
-
        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
-        feed_forward_hidden_states, _ = self.moe(hidden_states)
-        # residual connection
+        feed_forward_hidden_states, router_logits, aux_loss = self.moe(hidden_states)
        hidden_states = residual + feed_forward_hidden_states

        if use_cache:
-            outputs = (hidden_states,) + outputs
+            outputs = (hidden_states,) + outputs + (aux_loss,)
        else:
-            outputs = (hidden_states,) + outputs[1:]
+            outputs = (hidden_states,) + outputs + (aux_loss,)

-        return outputs # hidden_states, present, (attentions, cross_attentions)
+        return outputs # hidden_states, present, (attentions), aux_loss

class LOLAMOE(nn.Module):
    def __init__(self,
@@ -404,50 +452,41 @@ class LOLAMOE(nn.Module):
        self.experts = nn.ModuleList([GPT2MLP(inner_dim, config) for _ in range(self.num_experts)])

    def forward(self, hidden_states):
-        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py#L816
-        # FIXME do it as in top1gating
-        # https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/sharded_moe.py
-
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        router_logits = self.gate(hidden_states)
-        # router_logits = router_logits.squeeze(dim=0)
+        routing_probabilities = F.softmax(router_logits, dim=1)
+        routing_weights, selected_experts = torch.topk(routing_probabilities, self.top_k, dim=-1)
+        # Compute Expert Mask
+        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts)
+        expert_mask = expert_mask.sum(dim=1) # Shape: [batch_size * seq_length, num_experts]

-        # TODO: fix the weights logic to be the same as Megatron
-        routing_weights = F.softmax(router_logits, dim=1)
-        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
-        # routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-        # commenting the statement above for LOLA and removing the "/" operator to avoid getting weights as 1
-        routing_weights = routing_weights.sum(dim=-1, keepdim=True)
-        routing_weights = routing_weights.to(hidden_states.dtype)
+        # Compute Tokens per Expert and Router Probabilities
+        token_fraction_per_expert = expert_mask.float().sum(dim=0) / expert_mask.float().sum()
+        mean_router_prob_per_expert = routing_probabilities.mean(dim=0)

+        # Calculate Auxiliary Loss
+        aux_loss = torch.sum(token_fraction_per_expert * mean_router_prob_per_expert) * self.num_experts
+
+        # Proceed with MoE computation as before
        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )
-        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
-        for expert_idx in range(self.num_experts):
-            expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx])

-            if top_x.shape[0] == 0:
+        # Process tokens for each expert
+        for expert_idx in range(self.num_experts):
+            indices = (selected_experts == expert_idx).nonzero(as_tuple=True)[0]
+            if indices.numel() == 0:
                continue
+            current_states = hidden_states[indices]
+            current_output = self.experts[expert_idx](current_states)
+            current_weights = routing_weights[indices, (selected_experts[indices] == expert_idx).nonzero(as_tuple=True)[1]]
+            final_hidden_states.index_add_(0, indices, current_output * current_weights.unsqueeze(-1))

-            # in torch it is faster to index using lists than torch tensors
-            top_x_list = top_x.tolist()
-            idx_list = idx.tolist()
-
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
-            current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
-
-            # However `index_add_` only support torch tensors for indexing so we'll use
-            # the `top_x` tensor here.
-            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
-        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
-        return final_hidden_states, router_logits
+        final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
+        expert_analysis_callback(selected_experts)
+        return final_hidden_states, router_logits, aux_loss

class LOLAAttention(GPT2Attention):
    def __init__(self, config, is_cross_attention=False, layer_idx=None):
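To make the dispatch in the routing hunk above concrete: every token is sent only to its top-k experts, each expert processes its tokens in one batch, the outputs are scaled by the corresponding routing weights, and index_add_ scatters them back into token order. The toy snippet below mirrors that loop outside the model; the Linear stand-ins for GPT2MLP and all shapes are illustrative, not from the repository.

import torch

num_tokens, hidden, num_experts, top_k = 6, 4, 3, 1
hidden_states = torch.randn(num_tokens, hidden)
routing_weights = torch.rand(num_tokens, top_k)                           # weight of each selected expert
selected_experts = torch.randint(0, num_experts, (num_tokens, top_k))
experts = [torch.nn.Linear(hidden, hidden) for _ in range(num_experts)]   # stand-ins for GPT2MLP experts

final_hidden_states = torch.zeros_like(hidden_states)
for expert_idx in range(num_experts):
    # rows (tokens) whose top-k selection includes this expert
    indices = (selected_experts == expert_idx).nonzero(as_tuple=True)[0]
    if indices.numel() == 0:
        continue
    # which top-k slot picked this expert, to look up the matching routing weight
    slot = (selected_experts[indices] == expert_idx).nonzero(as_tuple=True)[1]
    weights = routing_weights[indices, slot]
    expert_out = experts[expert_idx](hidden_states[indices])
    # accumulate the weighted expert output back at the original token positions
    final_hidden_states.index_add_(0, indices, expert_out * weights.unsqueeze(-1))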
@@ -471,7 +510,7 @@ class LOLAAttention(GPT2Attention):
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
-                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                f"embed_dim must be divisible by num_heads (got embed_dim: {self.embed_dim} and num_heads:"
                f" {self.num_heads})."
            )

@@ -512,156 +551,63 @@ class LOLALMHeadModel(GPT2LMHeadModel):

        # Initialize weights and apply final processing
        self.post_init()
-
-
-class LOLADoubleHeadsModel(GPT2DoubleHeadsModel):

-    config_class = LOLAConfig
-
-    def __init__(self, config):
-        super(GPT2DoubleHeadsModel, self).__init__(config)
-        config.num_labels = 1
-        self.transformer = LOLAModel(config)
-        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
-        self.multiple_choice_head = SequenceSummary(config)
-
-        # Model parallel
-        self.model_parallel = False
-        self.device_map = None
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-
-class LOLAForSequenceClassification(GPT2ForSequenceClassification):
-
-    config_class = LOLAConfig
-
-    def __init__(self, config):
-        super(GPT2ForSequenceClassification, self).__init__(config)
-        self.num_labels = config.num_labels
-        self.transformer = LOLAModel(config)
-        self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
-
-        # Model parallel
-        self.model_parallel = False
-        self.device_map = None
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-class LOLAForTokenClassification(GPT2ForTokenClassification):
-
-    config_class = LOLAConfig
-
-    def __init__(self, config):
-        super(GPT2ForTokenClassification, self).__init__(config)
-        self.num_labels = config.num_labels
-
-        self.transformer = LOLAModel(config)
-        if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
-            classifier_dropout = config.classifier_dropout
-        elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
-            classifier_dropout = config.hidden_dropout
-        else:
-            classifier_dropout = 0.1
-        self.dropout = nn.Dropout(classifier_dropout)
-        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-
-        # Model parallel
-        self.model_parallel = False
-        self.device_map = None
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-class LOLAForQuestionAnswering(GPT2PreTrainedModel):
-
-    config_class = LOLAConfig
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.num_labels = config.num_labels
-        self.transformer = LOLAModel(config)
-        self.qa_outputs = nn.Linear(config.hidden_size, 2)
-
-        # Model parallel
-        self.model_parallel = False
-        self.device_map = None
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
    def forward(
        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        token_type_ids: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        start_positions: Optional[torch.LongTensor] = None,
-        end_positions: Optional[torch.LongTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
-        r"""
-        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels for position (index) of the start of the labelled span for computing the token classification loss.
-            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
-            are not taken into account for computing the loss.
-        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels for position (index) of the end of the labelled span for computing the token classification loss.
-            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
-            are not taken into account for computing the loss.
-        """
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        outputs = self.transformer(
+        transformer_outputs = self.transformer(
            input_ids,
+            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+            return_dict=True, # Ensure we get a MoeModelOutputWithPast
        )
-
-        sequence_output = outputs[0]
-
-        logits = self.qa_outputs(sequence_output)
-        start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1).contiguous()
-        end_logits = end_logits.squeeze(-1).contiguous()
-
-        total_loss = None
-        if start_positions is not None and end_positions is not None:
-            # If we are on multi-GPU, split add a dimension
-            if len(start_positions.size()) > 1:
-                start_positions = start_positions.squeeze(-1).to(start_logits.device)
-            if len(end_positions.size()) > 1:
-                end_positions = end_positions.squeeze(-1).to(end_logits.device)
-            # sometimes the start/end positions are outside our model inputs, we ignore these terms
-            ignored_index = start_logits.size(1)
-            start_positions = start_positions.clamp(0, ignored_index)
-            end_positions = end_positions.clamp(0, ignored_index)
-
-            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
-            start_loss = loss_fct(start_logits, start_positions)
-            end_loss = loss_fct(end_logits, end_positions)
-            total_loss = (start_loss + end_loss) / 2
+        hidden_states = transformer_outputs.last_hidden_state
+        lm_logits = self.lm_head(hidden_states)
+
+        aux_loss = transformer_outputs.aux_loss if hasattr(transformer_outputs, 'aux_loss') else None
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+            if aux_loss is not None:
+                loss += self.config.router_aux_loss_coef * aux_loss

        if not return_dict:
-            output = (start_logits, end_logits) + outputs[2:]
-            return ((total_loss,) + output) if total_loss is not None else output
-
-        return QuestionAnsweringModelOutput(
-            loss=total_loss,
-            start_logits=start_logits,
-            end_logits=end_logits,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
+            output = (lm_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return MoeCausalLMOutputWithPast(
+            loss=loss,
+            aux_loss=aux_loss,
+            logits=lm_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+            router_logits=transformer_outputs.router_logits,
        )
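For completeness, a hedged end-to-end sketch of the new auxiliary-loss path. It assumes the hosting repository id is dice-research/lola_v1, that the repo's auto_map wires AutoModelForCausalLM to LOLALMHeadModel, and that the config defines router_aux_loss_coef; adjust any of these if they differ. trust_remote_code=True is required because this custom modeling file is executed from the Hub.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "dice-research/lola_v1"  # assumed repo id; replace with the actual one
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("Mixture-of-experts routing test", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, labels=inputs["input_ids"], return_dict=True)

# With labels, `loss` already folds in config.router_aux_loss_coef * aux_loss;
# the auxiliary term is also exposed separately on the output object.
print(outputs.loss, outputs.aux_loss)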