Crystalcareai committed on
Commit cb8455c · verified · 1 Parent(s): 0f57763

Update modeling_gemmoe.py

Files changed (1)
  1. modeling_gemmoe.py +16 -47
modeling_gemmoe.py CHANGED
@@ -683,14 +683,7 @@ class GemmoeBlockSparseTop2MLP(GemmoeBlockSparseTop2MLP):
 
 class GemmoeSparseMoeBlock(nn.Module):
     """
-    This implementation is
-    strictly equivalent to standard MoE with full capacity (no
-    dropped tokens). It's faster since it formulates MoE operations
-    in terms of block-sparse operations to accomodate imbalanced
-    assignments of tokens to experts, whereas standard MoE either
-    (1) drop tokens at the cost of reduced performance or (2) set
-    capacity factor to number of experts and thus waste computation
-    and memory on padding.
+    This implementation is strictly equivalent to standard MoE with full capacity (no dropped tokens). It's faster since it formulates MoE operations in terms of block-sparse operations to accommodate imbalanced assignments of tokens to experts.
     """
 
     def __init__(self, config):
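
The shortened docstring keeps the core claim: every token is routed to its top-k experts with renormalized weights, and no tokens are dropped. A minimal, self-contained sketch of that gating math on toy tensors (the sizes and seed below are made up for illustration, not taken from the model config):

import torch
import torch.nn.functional as F

# Toy example: 4 tokens routed over 8 experts with top-2 gating.
# Shapes mirror the flattened (batch * sequence_length, n_experts) router output.
torch.manual_seed(0)
router_logits = torch.randn(4, 8)

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, k=2, dim=-1)
# Renormalize so each token's two expert weights sum to 1, matching the
# `topk_weight /= topk_weight.sum(dim=-1, keepdim=True)` line in the new forward.
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)

print(topk_idx)                 # the 2 experts each token is sent to
print(topk_weight.sum(dim=-1))  # tensor([1., 1., 1., 1.])
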
@@ -706,51 +699,26 @@ class GemmoeSparseMoeBlock(nn.Module):
         self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        """ """
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)
 
         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
-        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-        # we cast back to the input dtype
-        routing_weights = routing_weights.to(hidden_states.dtype)
-
-        final_hidden_states = torch.zeros(
-            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
-        )
-
-        # One hot encode the selected experts to create an expert mask
-        # this will be used to easily index which expert is going to be sollicitated
-        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
-
-        # Loop over all available experts in the model and perform the computation on each expert
-        for expert_idx in range(self.num_experts):
-            expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx])
-
-            if top_x.shape[0] == 0:
-                for param in expert_layer.parameters():
-                    if param.requires_grad:
-                        param.grad = torch.zeros_like(param)
-                continue
-
-            # in torch it is faster to index using lists than torch tensors
-            top_x_list = top_x.tolist()
-            idx_list = idx.tolist()
-
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
-            current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
-
-            # However `index_add_` only support torch tensors for indexing so we'll use
-            # the `top_x` tensor here.
-            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
-        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+        topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
+        topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
+        topk_weight = topk_weight.to(hidden_states.dtype)
+
+        hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
+        y = torch.empty_like(hidden_states)
+        flat_topk_idx = topk_idx.view(-1)
+        for i in range(self.num_experts):
+            expert = self.experts[i]
+            mask = flat_topk_idx == i
+            if mask.any():
+                y[mask] = expert(hidden_states[mask])
+        y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
+        final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
         return final_hidden_states, router_logits
 
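The removed and the added forward bodies compute the same weighted mixture of expert outputs; only the dispatch differs (one-hot expert mask plus index_add_ in the old path, repeat_interleave plus boolean masks in the new one). A self-contained sketch with dummy nn.Linear experts and made-up sizes, assuming nothing beyond what the diff shows, that runs both formulations and checks they agree:

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_experts, top_k, hidden_dim, tokens = 4, 2, 16, 8
experts = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts)])
hidden_states = torch.randn(tokens, hidden_dim)
router_logits = torch.randn(tokens, num_experts)

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1)
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
topk_weight = topk_weight.to(hidden_states.dtype)

# Old path: one-hot expert mask, gather each expert's tokens, scatter back with index_add_.
out_old = torch.zeros_like(hidden_states)
expert_mask = F.one_hot(topk_idx, num_classes=num_experts).permute(2, 1, 0)
for expert_idx in range(num_experts):
    idx, top_x = torch.where(expert_mask[expert_idx])
    if top_x.numel() == 0:
        continue
    current_out = experts[expert_idx](hidden_states[top_x]) * topk_weight[top_x, idx, None]
    out_old.index_add_(0, top_x, current_out)

# New path: duplicate every token top_k times, run each copy through its chosen expert,
# then take the weighted sum over the top_k copies.
dup = hidden_states.repeat_interleave(top_k, dim=0)
flat_idx = topk_idx.view(-1)
y = torch.empty_like(dup)
for i in range(num_experts):
    mask = flat_idx == i
    if mask.any():
        y[mask] = experts[i](dup[mask])
out_new = (y.view(tokens, top_k, hidden_dim) * topk_weight.unsqueeze(-1)).sum(dim=1)

print(torch.allclose(out_old, out_new, atol=1e-5))  # expected: True
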
 
@@ -1297,6 +1265,7 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
         logits = logits.float()
+
         if self.training:
             for expert in self.model.layers[-1].block_sparse_moe.experts:
                 for param in expert.parameters():
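
The last hunk is truncated before the body of the `for param in expert.parameters():` loop, so what the training-time loop actually does to each parameter is not shown here. The code removed above assigned `param.grad = torch.zeros_like(param)` to experts that received no tokens; purely as a hedged, self-contained illustration of that pattern (not necessarily what this commit does), giving untouched experts explicit zero gradients looks roughly like this:

import torch
import torch.nn as nn

# Hypothetical sketch: when only some experts process tokens in a step, the
# untouched experts end up with param.grad is None after backward(). Filling
# those with zeros gives every expert parameter a gradient entry.
experts = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])
x = torch.randn(5, 4)
loss = experts[0](x).sum()  # only expert 0 participates this step
loss.backward()

for expert in experts:
    for param in expert.parameters():
        if param.requires_grad and param.grad is None:
            param.grad = torch.zeros_like(param)

print([p.grad.abs().sum().item() for e in experts for p in e.parameters()])
# nonzero sums for expert 0, zeros for experts 1 and 2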
 