neo-nlp-dev committed
Commit: 3a7a4b9
Parent(s): 0507803

updating model class

- updating attention logic for GPT2Block to select GPT2Attention
- adding consider_aux_loss config to allow users to skip adding aux loss to the total loss

Files changed:
- configuration_lola_gpt2.py +2 -0
- modeling_lola_gpt2.py +10 -55
configuration_lola_gpt2.py
CHANGED

@@ -48,6 +48,7 @@ class LOLAConfig(PretrainedConfig):
         num_experts=16,
         topk=1,
         router_aux_loss_coef=0.01,
+        consider_aux_loss=True,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -77,6 +78,7 @@ class LOLAConfig(PretrainedConfig):
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
         self.router_aux_loss_coef = router_aux_loss_coef
+        self.consider_aux_loss = consider_aux_loss

         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

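Since consider_aux_loss is stored as a regular attribute on LOLAConfig, it can be overridden when loading a checkpoint. A minimal usage sketch, assuming a Hub checkpoint that ships these custom classes; the repository id below is a placeholder, not taken from this commit:

# Hedged usage sketch: disable the auxiliary router loss at load time.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "some-org/lola-gpt2",      # placeholder checkpoint id
    consider_aux_loss=False,   # new flag introduced by this commit
    trust_remote_code=True,    # needed to load the custom LOLA classes
)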
modeling_lola_gpt2.py
CHANGED

@@ -21,12 +21,8 @@ from torch.nn import CrossEntropyLoss

 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
-    MoeCausalLMOutputWithPast
-    SequenceClassifierOutputWithPast,
-    QuestionAnsweringModelOutput
+    MoeCausalLMOutputWithPast
 )
-from transformers.modeling_utils import SequenceSummary
-from transformers.pytorch_utils import Conv1D
 from transformers.utils import (
     logging
 )
@@ -40,7 +36,6 @@ from typing import Optional, Tuple
 import torch
 from transformers.modeling_outputs import ModelOutput
 import transformers
-import importlib.util


 logger = logging.get_logger(__name__)
@@ -50,7 +45,7 @@ expert_analysis_callback = lambda _: None
 class LOLADependencyChecker:
     def __init__(self):
         self.expected_versions = {
-            "transformers": "4.
+            "transformers": "4.47.0"
         }
         self.check_dependencies()

@@ -111,6 +106,8 @@ class LOLAModel(GPT2PreTrainedModel):
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

         self.drop = nn.Dropout(config.embd_pdrop)
+        # To make sure the GPTBlock selects the right attention
+        config._attn_implementation='eager'
         self.h = nn.ModuleList([
             GPT2Block(config, layer_idx=i) if i % 2 == 0 else LOLABlock(config, layer_idx=i) for i in range(config.num_hidden_layers)
         ])
@@ -384,6 +381,7 @@ class LOLABlock(nn.Module):

         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.attn = GPT2Attention(config, layer_idx=layer_idx)
+        #self.attn = GPT2SdpaAttention(config, layer_idx=layer_idx)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.moe = LOLAMOE(
             hidden_size,
@@ -488,53 +486,6 @@ class LOLAMOE(nn.Module):
         expert_analysis_callback(selected_experts)
         return final_hidden_states, router_logits, aux_loss

-class LOLAAttention(GPT2Attention):
-    def __init__(self, config, is_cross_attention=False, layer_idx=None):
-        super(GPT2Attention, SequenceClassifierOutputWithPast).__init__()
-
-        max_positions = config.max_position_embeddings
-        self.register_buffer(
-            "bias",
-            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
-                1, 1, max_positions, max_positions
-            ),
-            #persistent=False,
-        )
-        self.register_buffer("masked_bias", torch.tensor(-1e4),
-            #persistent=False
-        )
-
-        self.embed_dim = config.hidden_size
-        self.num_heads = config.num_attention_heads
-        self.head_dim = self.embed_dim // self.num_heads
-        self.split_size = self.embed_dim
-        if self.head_dim * self.num_heads != self.embed_dim:
-            raise ValueError(
-                f"embed_dim must be divisible by num_heads (got embed_dim: {self.embed_dim} and num_heads:"
-                f" {self.num_heads})."
-            )
-
-        self.scale_attn_weights = config.scale_attn_weights
-        self.is_cross_attention = is_cross_attention
-
-        # Layer-wise attention scaling, reordering, and upcasting
-        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
-        self.layer_idx = layer_idx
-        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
-
-        if self.is_cross_attention:
-            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
-            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
-        else:
-            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
-        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
-
-        self.attn_dropout = nn.Dropout(config.attn_pdrop)
-        self.resid_dropout = nn.Dropout(config.resid_pdrop)
-
-        self.pruned_heads = set()
-
-
 class LOLALMHeadModel(GPT2LMHeadModel):

     config_class = LOLAConfig
@@ -545,6 +496,9 @@ class LOLALMHeadModel(GPT2LMHeadModel):
         self.transformer = LOLAModel(config)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

+        # To add aux loss or not
+        self.consider_aux_loss = config.consider_aux_loss
+        logger.debug(f'consider_aux_loss is set to {self.consider_aux_loss}')
         # Model parallel
         self.model_parallel = False
         self.device_map = None
@@ -595,7 +549,8 @@ class LOLALMHeadModel(GPT2LMHeadModel):
             # Flatten the tokens
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-            if aux_loss is not None:
+            # We can avoid adding aux loss to the total loss if its not needed (e.g. LORA without targeting expert-gating)
+            if aux_loss is not None and self.consider_aux_loss:
                 loss += self.config.router_aux_loss_coef * aux_loss

         if not return_dict:
|