Crystalcareai commited on
Commit
0f57763
·
verified ·
1 Parent(s): 95f692f

Update modeling_gemmoe.py

Browse files
Files changed (1) hide show
  1. modeling_gemmoe.py +25 -4
modeling_gemmoe.py CHANGED
@@ -682,6 +682,17 @@ class GemmoeBlockSparseTop2MLP(GemmoeBlockSparseTop2MLP):
682
  super().__init__(*args, **kwargs)
683
 
684
  class GemmoeSparseMoeBlock(nn.Module):
 
 
 
 
 
 
 
 
 
 
 
685
  def __init__(self, config):
686
  super().__init__()
687
  self.hidden_dim = config.hidden_size
@@ -689,45 +700,56 @@ class GemmoeSparseMoeBlock(nn.Module):
689
  self.num_experts = config.num_local_experts
690
  self.top_k = config.num_experts_per_tok
691
 
 
692
  self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
 
693
  self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
694
 
695
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
696
  batch_size, sequence_length, hidden_dim = hidden_states.shape
697
  hidden_states = hidden_states.view(-1, hidden_dim)
698
-
699
  router_logits = self.gate(hidden_states)
700
 
701
  routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
702
  routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
703
  routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
 
704
  routing_weights = routing_weights.to(hidden_states.dtype)
705
 
706
  final_hidden_states = torch.zeros(
707
  (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
708
  )
709
 
 
 
710
  expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
711
 
 
712
  for expert_idx in range(self.num_experts):
713
  expert_layer = self.experts[expert_idx]
714
  idx, top_x = torch.where(expert_mask[expert_idx])
715
 
716
  if top_x.shape[0] == 0:
717
- # Handle unused parameters for this expert
718
  for param in expert_layer.parameters():
719
  if param.requires_grad:
720
  param.grad = torch.zeros_like(param)
721
  continue
722
 
 
723
  top_x_list = top_x.tolist()
724
  idx_list = idx.tolist()
725
 
 
 
 
726
  current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
727
  current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
728
 
 
 
729
  final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
730
-
731
  final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
732
  return final_hidden_states, router_logits
733
 
@@ -1275,7 +1297,6 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1275
  hidden_states = outputs[0]
1276
  logits = self.lm_head(hidden_states)
1277
  logits = logits.float()
1278
-
1279
  if self.training:
1280
  for expert in self.model.layers[-1].block_sparse_moe.experts:
1281
  for param in expert.parameters():
 
682
  super().__init__(*args, **kwargs)
683
 
684
  class GemmoeSparseMoeBlock(nn.Module):
685
+ """
686
+ This implementation is
687
+ strictly equivalent to standard MoE with full capacity (no
688
+ dropped tokens). It's faster since it formulates MoE operations
689
+ in terms of block-sparse operations to accomodate imbalanced
690
+ assignments of tokens to experts, whereas standard MoE either
691
+ (1) drop tokens at the cost of reduced performance or (2) set
692
+ capacity factor to number of experts and thus waste computation
693
+ and memory on padding.
694
+ """
695
+
696
  def __init__(self, config):
697
  super().__init__()
698
  self.hidden_dim = config.hidden_size
 
700
  self.num_experts = config.num_local_experts
701
  self.top_k = config.num_experts_per_tok
702
 
703
+ # gating
704
  self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
705
+
706
  self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
707
 
708
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
709
+ """ """
710
  batch_size, sequence_length, hidden_dim = hidden_states.shape
711
  hidden_states = hidden_states.view(-1, hidden_dim)
712
+ # router_logits: (batch * sequence_length, n_experts)
713
  router_logits = self.gate(hidden_states)
714
 
715
  routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
716
  routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
717
  routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
718
+ # we cast back to the input dtype
719
  routing_weights = routing_weights.to(hidden_states.dtype)
720
 
721
  final_hidden_states = torch.zeros(
722
  (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
723
  )
724
 
725
+ # One hot encode the selected experts to create an expert mask
726
+ # this will be used to easily index which expert is going to be sollicitated
727
  expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
728
 
729
+ # Loop over all available experts in the model and perform the computation on each expert
730
  for expert_idx in range(self.num_experts):
731
  expert_layer = self.experts[expert_idx]
732
  idx, top_x = torch.where(expert_mask[expert_idx])
733
 
734
  if top_x.shape[0] == 0:
 
735
  for param in expert_layer.parameters():
736
  if param.requires_grad:
737
  param.grad = torch.zeros_like(param)
738
  continue
739
 
740
+ # in torch it is faster to index using lists than torch tensors
741
  top_x_list = top_x.tolist()
742
  idx_list = idx.tolist()
743
 
744
+ # Index the correct hidden states and compute the expert hidden state for
745
+ # the current expert. We need to make sure to multiply the output hidden
746
+ # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
747
  current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
748
  current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
749
 
750
+ # However `index_add_` only support torch tensors for indexing so we'll use
751
+ # the `top_x` tensor here.
752
  final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
 
753
  final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
754
  return final_hidden_states, router_logits
755
 
 
1297
  hidden_states = outputs[0]
1298
  logits = self.lm_head(hidden_states)
1299
  logits = logits.float()
 
1300
  if self.training:
1301
  for expert in self.model.layers[-1].block_sparse_moe.experts:
1302
  for param in expert.parameters():