Crystalcareai
commited on
Update modeling_gemmoe.py
Browse files- modeling_gemmoe.py +16 -47
modeling_gemmoe.py
CHANGED
@@ -683,14 +683,7 @@ class GemmoeBlockSparseTop2MLP(GemmoeBlockSparseTop2MLP):
|
|
683 |
|
684 |
class GemmoeSparseMoeBlock(nn.Module):
|
685 |
"""
|
686 |
-
This implementation is
|
687 |
-
strictly equivalent to standard MoE with full capacity (no
|
688 |
-
dropped tokens). It's faster since it formulates MoE operations
|
689 |
-
in terms of block-sparse operations to accomodate imbalanced
|
690 |
-
assignments of tokens to experts, whereas standard MoE either
|
691 |
-
(1) drop tokens at the cost of reduced performance or (2) set
|
692 |
-
capacity factor to number of experts and thus waste computation
|
693 |
-
and memory on padding.
|
694 |
"""
|
695 |
|
696 |
def __init__(self, config):
|
@@ -706,51 +699,26 @@ class GemmoeSparseMoeBlock(nn.Module):
|
|
706 |
self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
|
707 |
|
708 |
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
709 |
-
""" """
|
710 |
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
711 |
hidden_states = hidden_states.view(-1, hidden_dim)
|
712 |
# router_logits: (batch * sequence_length, n_experts)
|
713 |
router_logits = self.gate(hidden_states)
|
714 |
|
715 |
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
716 |
-
|
717 |
-
|
718 |
-
|
719 |
-
|
720 |
-
|
721 |
-
|
722 |
-
|
723 |
-
)
|
724 |
-
|
725 |
-
|
726 |
-
|
727 |
-
|
728 |
-
|
729 |
-
|
730 |
-
for expert_idx in range(self.num_experts):
|
731 |
-
expert_layer = self.experts[expert_idx]
|
732 |
-
idx, top_x = torch.where(expert_mask[expert_idx])
|
733 |
-
|
734 |
-
if top_x.shape[0] == 0:
|
735 |
-
for param in expert_layer.parameters():
|
736 |
-
if param.requires_grad:
|
737 |
-
param.grad = torch.zeros_like(param)
|
738 |
-
continue
|
739 |
-
|
740 |
-
# in torch it is faster to index using lists than torch tensors
|
741 |
-
top_x_list = top_x.tolist()
|
742 |
-
idx_list = idx.tolist()
|
743 |
-
|
744 |
-
# Index the correct hidden states and compute the expert hidden state for
|
745 |
-
# the current expert. We need to make sure to multiply the output hidden
|
746 |
-
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
|
747 |
-
current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
|
748 |
-
current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
|
749 |
-
|
750 |
-
# However `index_add_` only support torch tensors for indexing so we'll use
|
751 |
-
# the `top_x` tensor here.
|
752 |
-
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
|
753 |
-
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
754 |
return final_hidden_states, router_logits
|
755 |
|
756 |
|
@@ -1297,6 +1265,7 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
|
|
1297 |
hidden_states = outputs[0]
|
1298 |
logits = self.lm_head(hidden_states)
|
1299 |
logits = logits.float()
|
|
|
1300 |
if self.training:
|
1301 |
for expert in self.model.layers[-1].block_sparse_moe.experts:
|
1302 |
for param in expert.parameters():
|
|
|
683 |
|
684 |
class GemmoeSparseMoeBlock(nn.Module):
|
685 |
"""
|
686 |
+
This implementation is strictly equivalent to standard MoE with full capacity (no dropped tokens). It's faster since it formulates MoE operations in terms of block-sparse operations to accommodate imbalanced assignments of tokens to experts.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
687 |
"""
|
688 |
|
689 |
def __init__(self, config):
|
|
|
699 |
self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
|
700 |
|
701 |
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
|
|
702 |
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
703 |
hidden_states = hidden_states.view(-1, hidden_dim)
|
704 |
# router_logits: (batch * sequence_length, n_experts)
|
705 |
router_logits = self.gate(hidden_states)
|
706 |
|
707 |
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
708 |
+
topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
|
709 |
+
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
|
710 |
+
topk_weight = topk_weight.to(hidden_states.dtype)
|
711 |
+
|
712 |
+
hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
|
713 |
+
y = torch.empty_like(hidden_states)
|
714 |
+
flat_topk_idx = topk_idx.view(-1)
|
715 |
+
for i in range(self.num_experts):
|
716 |
+
expert = self.experts[i]
|
717 |
+
mask = flat_topk_idx == i
|
718 |
+
if mask.any():
|
719 |
+
y[mask] = expert(hidden_states[mask])
|
720 |
+
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
721 |
+
final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
722 |
return final_hidden_states, router_logits
|
723 |
|
724 |
|
|
|
1265 |
hidden_states = outputs[0]
|
1266 |
logits = self.lm_head(hidden_states)
|
1267 |
logits = logits.float()
|
1268 |
+
|
1269 |
if self.training:
|
1270 |
for expert in self.model.layers[-1].block_sparse_moe.experts:
|
1271 |
for param in expert.parameters():
|