Text Generation
Transformers
Safetensors
lola_v1
custom_code
neo-nlp-dev committed
Commit 3a7a4b9 · 1 Parent(s): 0507803

updating model class


- updating the attention logic so that GPT2Block selects GPT2Attention
- adding a consider_aux_loss config option so users can skip adding the MoE router auxiliary loss to the total loss (see the usage sketch below)
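
A minimal usage sketch of the new flag (not part of this commit; the repo id below is assumed for illustration): the auxiliary loss can be switched off at load time, e.g. before adapter fine-tuning that does not touch the expert gating.

    from transformers import AutoConfig, AutoModelForCausalLM

    # Load the custom config shipped with this repository and flip the new flag.
    config = AutoConfig.from_pretrained("dice-research/lola_v1", trust_remote_code=True)
    config.consider_aux_loss = False  # do not add the router aux loss to the total loss

    model = AutoModelForCausalLM.from_pretrained(
        "dice-research/lola_v1",
        config=config,
        trust_remote_code=True,
    )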

Files changed (2)
  1. configuration_lola_gpt2.py +2 -0
  2. modeling_lola_gpt2.py +10 -55
configuration_lola_gpt2.py CHANGED
@@ -48,6 +48,7 @@ class LOLAConfig(PretrainedConfig):
         num_experts=16,
         topk=1,
         router_aux_loss_coef=0.01,
+        consider_aux_loss=True,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -77,6 +78,7 @@ class LOLAConfig(PretrainedConfig):
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
         self.router_aux_loss_coef = router_aux_loss_coef
+        self.consider_aux_loss = consider_aux_loss
 
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
 
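Since the flag defaults to True, existing checkpoints keep their previous behaviour. A short sketch of constructing the config directly (the import path is assumed; it mirrors the file above):

    from configuration_lola_gpt2 import LOLAConfig  # assumed local import path

    cfg_default = LOLAConfig()                        # consider_aux_loss defaults to True (old behaviour)
    cfg_no_aux = LOLAConfig(consider_aux_loss=False)  # aux loss is still computed, just not added to the loss
    print(cfg_no_aux.consider_aux_loss)               # False; written to config.json by save_pretrained()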
modeling_lola_gpt2.py CHANGED
@@ -21,12 +21,8 @@ from torch.nn import CrossEntropyLoss
 
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
-    MoeCausalLMOutputWithPast,
-    SequenceClassifierOutputWithPast,
-    QuestionAnsweringModelOutput
+    MoeCausalLMOutputWithPast
 )
-from transformers.modeling_utils import SequenceSummary
-from transformers.pytorch_utils import Conv1D
 from transformers.utils import (
     logging
 )
@@ -40,7 +36,6 @@ from typing import Optional, Tuple
 import torch
 from transformers.modeling_outputs import ModelOutput
 import transformers
-import importlib.util
 
 
 logger = logging.get_logger(__name__)
@@ -50,7 +45,7 @@ expert_analysis_callback = lambda _: None
 class LOLADependencyChecker:
     def __init__(self):
         self.expected_versions = {
-            "transformers": "4.38.2"
+            "transformers": "4.47.0"
         }
         self.check_dependencies()
 
@@ -111,6 +106,8 @@ class LOLAModel(GPT2PreTrainedModel):
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
 
         self.drop = nn.Dropout(config.embd_pdrop)
+        # To make sure the GPT2Block selects the right attention
+        config._attn_implementation = 'eager'
         self.h = nn.ModuleList([
             GPT2Block(config, layer_idx=i) if i % 2 == 0 else LOLABlock(config, layer_idx=i) for i in range(config.num_hidden_layers)
         ])
@@ -384,6 +381,7 @@ class LOLABlock(nn.Module):
 
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.attn = GPT2Attention(config, layer_idx=layer_idx)
+        # self.attn = GPT2SdpaAttention(config, layer_idx=layer_idx)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.moe = LOLAMOE(
             hidden_size,
@@ -488,53 +486,6 @@ class LOLAMOE(nn.Module):
         expert_analysis_callback(selected_experts)
         return final_hidden_states, router_logits, aux_loss
 
-class LOLAAttention(GPT2Attention):
-    def __init__(self, config, is_cross_attention=False, layer_idx=None):
-        super(GPT2Attention, SequenceClassifierOutputWithPast).__init__()
-
-        max_positions = config.max_position_embeddings
-        self.register_buffer(
-            "bias",
-            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
-                1, 1, max_positions, max_positions
-            ),
-            #persistent=False,
-        )
-        self.register_buffer("masked_bias", torch.tensor(-1e4),
-            #persistent=False
-        )
-
-        self.embed_dim = config.hidden_size
-        self.num_heads = config.num_attention_heads
-        self.head_dim = self.embed_dim // self.num_heads
-        self.split_size = self.embed_dim
-        if self.head_dim * self.num_heads != self.embed_dim:
-            raise ValueError(
-                f"embed_dim must be divisible by num_heads (got embed_dim: {self.embed_dim} and num_heads:"
-                f" {self.num_heads})."
-            )
-
-        self.scale_attn_weights = config.scale_attn_weights
-        self.is_cross_attention = is_cross_attention
-
-        # Layer-wise attention scaling, reordering, and upcasting
-        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
-        self.layer_idx = layer_idx
-        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
-
-        if self.is_cross_attention:
-            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
-            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
-        else:
-            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
-        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
-
-        self.attn_dropout = nn.Dropout(config.attn_pdrop)
-        self.resid_dropout = nn.Dropout(config.resid_pdrop)
-
-        self.pruned_heads = set()
-
-
 class LOLALMHeadModel(GPT2LMHeadModel):
 
     config_class = LOLAConfig
@@ -545,6 +496,9 @@ class LOLALMHeadModel(GPT2LMHeadModel):
         self.transformer = LOLAModel(config)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
+        # Whether to add the aux loss to the total loss
+        self.consider_aux_loss = config.consider_aux_loss
+        logger.debug(f'consider_aux_loss is set to {self.consider_aux_loss}')
         # Model parallel
         self.model_parallel = False
         self.device_map = None
@@ -595,7 +549,8 @@ class LOLALMHeadModel(GPT2LMHeadModel):
             # Flatten the tokens
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-            if aux_loss is not None:
+            # We can avoid adding the aux loss to the total loss if it's not needed (e.g. LoRA without targeting the expert gating)
+            if aux_loss is not None and self.consider_aux_loss:
                 loss += self.config.router_aux_loss_coef * aux_loss
 
         if not return_dict:
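
For reference, a self-contained sketch of the loss composition that the hunk above implements (names mirror the diff; this is not the model code itself):

    import torch

    def total_loss(ce_loss, aux_loss, router_aux_loss_coef=0.01, consider_aux_loss=True):
        # Combine the LM cross-entropy loss with the optional MoE router auxiliary loss.
        loss = ce_loss
        if aux_loss is not None and consider_aux_loss:
            loss = loss + router_aux_loss_coef * aux_loss
        return loss

    # With consider_aux_loss=False only the cross-entropy term remains:
    print(total_loss(torch.tensor(2.30), torch.tensor(0.05), consider_aux_loss=False))  # tensor(2.3000)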