Crystalcareai committed
Commit 7a82a8d · verified · 1 Parent(s): 6317b4a

Update modeling_gemmoe.py

Files changed (1):
  modeling_gemmoe.py  +39 -19
modeling_gemmoe.py CHANGED
@@ -25,8 +25,6 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
-from .mlp import MLP
-
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache, StaticCache
 from transformers.modeling_attn_mask_utils import (
@@ -631,39 +629,61 @@ GEMMOE_ATTENTION_CLASSES = {
     "sdpa": GemmoeSdpaAttention,
 }
 
+class GemmoeBlockSparseTop2MLP(nn.Module):
+    def __init__(self, config: GemmoeConfig):
+        super().__init__()
+        self.ffn_dim = config.intermediate_size
+        self.hidden_dim = config.hidden_size
+
+        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
+        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
+        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
+
+        self.act_fn = approx_gelu
+
+    def forward(self, hidden_states):
+        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
+        current_hidden_states = self.w2(current_hidden_states)
+        return current_hidden_states.to(hidden_states.dtype)
+
+
 class GemmoeSparseMoeBlock(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.hidden_dim = config.hidden_size
         self.ffn_dim = config.intermediate_size
         self.num_experts = config.num_local_experts
-        self.top_k = config.num_experts_per_tok
+        self.top_k = 2
 
         # gating
         self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
 
-        self.mlp = MLP(
-            input_size=self.hidden_dim,
-            hidden_size=self.ffn_dim,
-            activation=nn.GELU(),
-            num_experts=self.num_experts,
-            top_k=self.top_k
-        )
+        self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
-
+
+        # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
-        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-        routing_weights = routing_weights.to(hidden_states.dtype)
+        routing_weights = F.softmax(router_logits, dim=1)
+        topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
+        topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
+
+        hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
+
+        y = torch.empty_like(hidden_states)
+
+        flat_topk_idx = topk_idx.view(-1)
+        for i in range(self.num_experts):
+            expert = self.experts[i]
+            expert_output = expert(hidden_states[flat_topk_idx == i])
+            y[flat_topk_idx == i] = expert_output
 
-        hidden_states = self.mlp(hidden_states, routing_weights, selected_experts)
-        hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+        y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
 
-        return hidden_states, router_logits
+        final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states.to(hidden_states.dtype), router_logits.to(hidden_states.dtype)
 
 
 # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->GEMMOE,Llama->Gemmoe
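
For reference, the routing path the new block introduces can be exercised on its own: softmax the router logits, keep the top-2 weights per token and renormalize them, duplicate each token once per selected expert with repeat_interleave, run every expert on its slice of tokens, and recombine the expert outputs with the routing weights. The sketch below is a minimal, self-contained illustration of that pattern with toy dimensions; ToyExpert and the tanh-approximated F.gelu are stand-ins for GemmoeBlockSparseTop2MLP and the file's approx_gelu helper, and none of it is part of the committed code.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy sizes, chosen only for illustration.
hidden_dim, ffn_dim, num_experts, top_k = 8, 16, 4, 2
batch_size, seq_len = 2, 3

class ToyExpert(nn.Module):
    """Stand-in for GemmoeBlockSparseTop2MLP: gated MLP with w1/w3 up-projections and a w2 down-projection."""
    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)
        self.w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)

    def forward(self, x):
        # F.gelu(..., approximate="tanh") stands in for the file's approx_gelu.
        return self.w2(F.gelu(self.w1(x), approximate="tanh") * self.w3(x))

gate = nn.Linear(hidden_dim, num_experts, bias=False)
experts = nn.ModuleList([ToyExpert() for _ in range(num_experts)])

x = torch.randn(batch_size, seq_len, hidden_dim)
tokens = x.view(-1, hidden_dim)                                   # (B*S, H)

# Top-2 routing: softmax over experts, keep the two largest weights per token,
# renormalize so the kept weights sum to 1.
router_logits = gate(tokens)                                      # (B*S, E)
routing_weights = F.softmax(router_logits, dim=1)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1)
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)

# Duplicate each token once per selected expert, dispatch, then recombine.
expanded = tokens.repeat_interleave(top_k, dim=0)                 # (B*S*top_k, H)
flat_idx = topk_idx.view(-1)                                      # (B*S*top_k,)
y = torch.empty_like(expanded)
for i, expert in enumerate(experts):
    mask = flat_idx == i
    if mask.any():                                                # skip experts that received no tokens
        y[mask] = expert(expanded[mask])

out = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
out = out.view(batch_size, seq_len, hidden_dim)
print(out.shape)                                                  # torch.Size([2, 3, 8])

The visible trade-off in the diff is that the fused MLP call from the dropped local .mlp module is replaced by a plain per-expert loop over an nn.ModuleList, which removes that extra dependency at the cost of dispatching each expert separately.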