Crystalcareai commited on
Commit
cb869dc
·
verified ·
1 Parent(s): 86a52a5

Update modeling_gemmoe.py

Browse files
Files changed (1) hide show
  1. modeling_gemmoe.py +19 -39
modeling_gemmoe.py CHANGED
@@ -25,6 +25,8 @@ import torch.utils.checkpoint
25
  from torch import nn
26
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
 
 
 
28
  from transformers.activations import ACT2FN
29
  from transformers.cache_utils import Cache, DynamicCache, StaticCache
30
  from transformers.modeling_attn_mask_utils import (
@@ -629,61 +631,39 @@ GEMMOE_ATTENTION_CLASSES = {
629
  "sdpa": GemmoeSdpaAttention,
630
  }
631
 
632
- class GemmoeBlockSparseTop2MLP(nn.Module):
633
- def __init__(self, config: GemmoeConfig):
634
- super().__init__()
635
- self.ffn_dim = config.intermediate_size
636
- self.hidden_dim = config.hidden_size
637
-
638
- self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
639
- self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
640
- self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
641
-
642
- self.act_fn = approx_gelu
643
-
644
- def forward(self, hidden_states):
645
- current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
646
- current_hidden_states = self.w2(current_hidden_states)
647
- return current_hidden_states.to(hidden_states.dtype)
648
-
649
-
650
  class GemmoeSparseMoeBlock(nn.Module):
651
  def __init__(self, config):
652
  super().__init__()
653
  self.hidden_dim = config.hidden_size
654
  self.ffn_dim = config.intermediate_size
655
  self.num_experts = config.num_local_experts
656
- self.top_k = 2
657
 
658
  # gating
659
  self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
660
 
661
- self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
 
 
 
 
 
 
662
 
663
- def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
664
  batch_size, sequence_length, hidden_dim = hidden_states.shape
665
  hidden_states = hidden_states.view(-1, hidden_dim)
666
-
667
- # router_logits: (batch * sequence_length, n_experts)
668
  router_logits = self.gate(hidden_states)
669
- routing_weights = F.softmax(router_logits, dim=1)
670
- topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
671
- topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
672
-
673
- hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
674
-
675
- y = torch.empty_like(hidden_states)
676
-
677
- flat_topk_idx = topk_idx.view(-1)
678
- for i in range(self.num_experts):
679
- expert = self.experts[i]
680
- expert_output = expert(hidden_states[flat_topk_idx == i])
681
- y[flat_topk_idx == i] = expert_output
682
 
683
- y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
 
684
 
685
- final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
686
- return final_hidden_states.to(hidden_states.dtype), router_logits.to(hidden_states.dtype)
687
 
688
 
689
  # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->GEMMOE,Llama->Gemmoe
 
25
  from torch import nn
26
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
 
28
+ from .mlp import MLP
29
+
30
  from transformers.activations import ACT2FN
31
  from transformers.cache_utils import Cache, DynamicCache, StaticCache
32
  from transformers.modeling_attn_mask_utils import (
 
631
  "sdpa": GemmoeSdpaAttention,
632
  }
633
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
634
  class GemmoeSparseMoeBlock(nn.Module):
635
  def __init__(self, config):
636
  super().__init__()
637
  self.hidden_dim = config.hidden_size
638
  self.ffn_dim = config.intermediate_size
639
  self.num_experts = config.num_local_experts
640
+ self.top_k = config.num_experts_per_tok
641
 
642
  # gating
643
  self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
644
 
645
+ self.mlp = MLP(
646
+ input_size=self.hidden_dim,
647
+ hidden_size=self.ffn_dim,
648
+ activation=nn.GELU(),
649
+ num_experts=self.num_experts,
650
+ top_k=self.top_k
651
+ )
652
 
653
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
654
  batch_size, sequence_length, hidden_dim = hidden_states.shape
655
  hidden_states = hidden_states.view(-1, hidden_dim)
656
+
 
657
  router_logits = self.gate(hidden_states)
658
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
659
+ routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
660
+ routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
661
+ routing_weights = routing_weights.to(hidden_states.dtype)
 
 
 
 
 
 
 
 
 
662
 
663
+ hidden_states = self.mlp(hidden_states, routing_weights, selected_experts)
664
+ hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
665
 
666
+ return hidden_states, router_logits
 
667
 
668
 
669
  # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->GEMMOE,Llama->Gemmoe