Update modeling_gemmoe.py
modeling_gemmoe.py
CHANGED (+25 -4)
@@ -682,6 +682,17 @@ class GemmoeBlockSparseTop2MLP(GemmoeBlockSparseTop2MLP):
         super().__init__(*args, **kwargs)
 
 class GemmoeSparseMoeBlock(nn.Module):
+    """
+    This implementation is
+    strictly equivalent to standard MoE with full capacity (no
+    dropped tokens). It is faster since it formulates MoE operations
+    in terms of block-sparse operations to accommodate imbalanced
+    assignments of tokens to experts, whereas standard MoE either
+    (1) drops tokens at the cost of reduced performance or (2) sets
+    the capacity factor to the number of experts and thus wastes
+    computation and memory on padding.
+    """
+
     def __init__(self, config):
         super().__init__()
         self.hidden_dim = config.hidden_size
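The added docstring (the same wording used by the Mixtral sparse MoE block in `transformers`) contrasts full-capacity routing with capacity-factor MoE. A minimal sketch of that trade-off, with made-up numbers and names that are not part of the diff:

```python
# Why full-capacity routing avoids the trade-off the docstring describes:
# with a fixed capacity, an expert assigned more tokens than its capacity
# must drop the overflow (or everything must be padded to the worst case),
# while the block-sparse formulation processes exactly the tokens each
# expert actually received.
import torch

num_experts, num_tokens, top_k = 4, 8, 2
torch.manual_seed(0)

# Hypothetical imbalanced assignment: (num_tokens, top_k) expert indices
selected_experts = torch.randint(0, num_experts, (num_tokens, top_k))
counts = torch.bincount(selected_experts.flatten(), minlength=num_experts)
print("assignments per expert:", counts.tolist())

capacity = (num_tokens * top_k) // num_experts  # uniform capacity = 4 here
dropped = (counts - capacity).clamp(min=0).sum().item()
padded = (capacity - counts).clamp(min=0).sum().item()
print(f"fixed capacity {capacity}: {dropped} assignments dropped "
      f"(or {padded} padded slots wasted); full capacity: neither")
```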
@@ -689,21 +700,25 @@ class GemmoeSparseMoeBlock(nn.Module):
         self.num_experts = config.num_local_experts
         self.top_k = config.num_experts_per_tok
 
+        # gating
         self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+
         self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """ """
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
-
+        # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)
 
         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
         routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
         routing_weights = routing_weights.to(hidden_states.dtype)
 
         final_hidden_states = torch.zeros(
             (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
         )
 
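The routing block above computes the gate in float32 and only casts back to the input dtype after the top-k renormalization, which keeps the softmax stable under fp16/bf16 inputs. A self-contained sketch of the math for a single token and four experts (values are illustrative, not from the diff):

```python
# Top-2 routing: softmax over experts, keep the two largest weights,
# then renormalize so the kept weights sum to 1 per token.
import torch
import torch.nn.functional as F

router_logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])  # one token, 4 experts
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, k=2, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
print(selected_experts)  # tensor([[0, 1]])
print(routing_weights)   # ~tensor([[0.7311, 0.2689]]) after renormalization
```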
@@ -710,12 +725,14 @@ class GemmoeSparseMoeBlock(nn.Module):
+        # One-hot encode the selected experts to create an expert mask;
+        # this will be used to easily index which expert is going to be solicited
         expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
 
+        # Loop over all available experts in the model and perform the computation on each expert
         for expert_idx in range(self.num_experts):
             expert_layer = self.experts[expert_idx]
             idx, top_x = torch.where(expert_mask[expert_idx])
 
             if top_x.shape[0] == 0:
-                # Handle unused parameters for this expert
                 for param in expert_layer.parameters():
                     if param.requires_grad:
                         param.grad = torch.zeros_like(param)
                 continue
 
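Two notes on this hunk. First, after `permute(2, 1, 0)` the mask has shape `(num_experts, top_k, num_tokens)`, so `torch.where(expert_mask[expert_idx])` yields, for one expert, the top-k slot (`idx`) and token index (`top_x`) of every assignment it received. Second, the pre-existing `param.grad = torch.zeros_like(param)` branch appears to pre-fill zero gradients for experts that received no tokens, presumably so every parameter has a defined gradient during training. A shape-checking sketch of the mask indexing (toy values, not from the diff):

```python
# selected_experts: (tokens=3, top_k=2); 3 experts in total.
import torch

selected_experts = torch.tensor([[0, 2], [1, 2], [2, 0]])
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=3).permute(2, 1, 0)
# expert_mask: (num_experts, top_k, tokens); expert_mask[e, k, t] == 1 iff
# token t picked expert e as its k-th choice.
idx, top_x = torch.where(expert_mask[2])
print(top_x)  # tensor([2, 0, 1]) -> token indices routed to expert 2
print(idx)    # tensor([0, 1, 1]) -> which top-k slot expert 2 fills per token
```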
@@ -722,12 +739,17 @@ class GemmoeSparseMoeBlock(nn.Module):
 
+            # in torch it is faster to index using lists than torch tensors
             top_x_list = top_x.tolist()
             idx_list = idx.tolist()
 
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
             current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
             current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
 
+            # However, `index_add_` only supports torch tensors for indexing, so we'll use
+            # the `top_x` tensor here.
             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
-
         final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
         return final_hidden_states, router_logits
 
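This gather/scatter pair is the core of the block-sparse formulation: one expert's rows are gathered with Python-list indexing, weighted by each token's routing weight, and accumulated back with `index_add_`. A toy sketch with an identity "expert" (assumed shapes, not from the diff):

```python
import torch

hidden_states = torch.arange(8.0).reshape(4, 2)  # (tokens=4, hidden_dim=2)
top_x = torch.tensor([0, 3])                     # tokens routed to this expert
weights = torch.tensor([[0.7], [0.3]])           # their routing weights, as a column

# Gather, then stand in for expert_layer(current_state) * routing_weights[...]
current_state = hidden_states[None, top_x.tolist()].reshape(-1, 2)
current_hidden_states = current_state * weights  # identity "expert", just weighted

final_hidden_states = torch.zeros(4, 2)
final_hidden_states.index_add_(0, top_x, current_hidden_states)
print(final_hidden_states)
# tensor([[0.0000, 0.7000],
#         [0.0000, 0.0000],
#         [0.0000, 0.0000],
#         [1.8000, 2.1000]])
```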
@@ -1275,7 +1297,6 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
         logits = logits.float()
-
         if self.training:
             for expert in self.model.layers[-1].block_sparse_moe.experts:
                 for param in expert.parameters():
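This final hunk only removes a blank line, but its context shows the same unused-parameter pattern applied during training to the experts of the last decoder layer. A sketch of the failure mode that pattern appears to guard against (toy modules, not from the diff):

```python
# An expert skipped in the forward pass gets no gradient, so its .grad stays
# None after backward(); pre-filling zeros keeps every parameter's grad
# defined, which matters for setups that reduce or inspect all grads.
import torch
import torch.nn as nn

experts = nn.ModuleList([nn.Linear(2, 2) for _ in range(2)])
x = torch.randn(3, 2)
out = experts[0](x)          # only expert 0 participates
out.sum().backward()

print(experts[0].weight.grad is None)  # False
print(experts[1].weight.grad is None)  # True -> the unused expert
for param in experts[1].parameters():
    if param.requires_grad:
        param.grad = torch.zeros_like(param)
print(experts[1].weight.grad is None)  # False after the workaround
```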