Crystalcareai
committed
Update modeling_gemmoe.py
modeling_gemmoe.py +27 -90
modeling_gemmoe.py CHANGED
@@ -167,10 +167,8 @@ class GemmoeRMSNorm(nn.Module):
         self.weight = nn.Parameter(torch.zeros(dim))
 
     def _norm(self, x):
-
-
-        mean = x_float.pow(2).mean(-1, keepdim=True)
-        normed_x = x_float * torch.rsqrt(mean + self.eps)
+        x_float = x.float()
+        normed_x = x_float * torch.rsqrt(x_float.pow(2).mean(-1, keepdim=True) + self.eps)
         return normed_x
 
     def forward(self, x):
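For reference, a minimal standalone sketch of the simplified normalization above, with an arbitrary eps; the real module also applies its learned weight in forward():

import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Upcast to float32, then divide by the root mean square over the last dimension.
    x_float = x.float()
    return x_float * torch.rsqrt(x_float.pow(2).mean(-1, keepdim=True) + eps)

print(rms_norm(torch.randn(2, 4, 8, dtype=torch.bfloat16)).shape)  # torch.Size([2, 4, 8])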
@@ -191,16 +189,14 @@ class GemmoeRotaryEmbedding(nn.Module):
 
     def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.max_seq_len_cached = seq_len
-        freq_exponents = (2.0 / self.dim) * (
-            torch.arange(self.dim // 2, dtype=torch.int64, device="cpu").float()
-        )
+        freq_exponents = (2.0 / self.dim) * (torch.arange(self.dim // 2, dtype=torch.float32, device="cpu").float())
         timescale = self.base ** freq_exponents
-        positions = torch.arange(self.max_seq_len_cached, device="cpu", dtype=torch.
+        positions = torch.arange(self.max_seq_len_cached, device="cpu", dtype=torch.float32).float()
         radians_new = positions[..., None] / timescale[None, None, :]
         radians_new = radians_new.squeeze(0)
         emb = torch.cat((radians_new, radians_new), dim=-1)
-        cos = emb.cos().to(device=device, non_blocking=True)
-        sin = emb.sin().to(device=device, non_blocking=True)
+        cos = emb.cos().to(device=device, dtype=dtype, non_blocking=True)
+        sin = emb.sin().to(device=device, dtype=dtype, non_blocking=True)
         self.register_buffer("cos_cached", cos, persistent=False)
         self.register_buffer("sin_cached", sin, persistent=False)
 
@@ -209,10 +205,7 @@ class GemmoeRotaryEmbedding(nn.Module):
         seq_len = x.size(2)
         if seq_len > self.max_seq_len_cached:
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-        return
-            self.cos_cached[:seq_len],
-            self.sin_cached[:seq_len],
-        )
+        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
 
 # Copied from transformers.models.llama.modeling_llama.rotate_half
 def rotate_half(x):
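A rough sketch of what the rewritten cache computes, run on toy sizes (dim, base and seq_len here are invented, not the model config):

import torch

dim, base, seq_len = 8, 10000.0, 16  # toy values

freq_exponents = (2.0 / dim) * torch.arange(dim // 2, dtype=torch.float32).float()
timescale = base ** freq_exponents                       # (dim // 2,)
positions = torch.arange(seq_len, dtype=torch.float32)
radians = (positions[..., None] / timescale[None, None, :]).squeeze(0)
emb = torch.cat((radians, radians), dim=-1)              # (seq_len, dim)
cos_cached, sin_cached = emb.cos(), emb.sin()
cos, sin = cos_cached[:seq_len], sin_cached[:seq_len]    # what forward() now returns
print(cos.shape, sin.shape)                              # torch.Size([16, 8]) twice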
@@ -222,27 +215,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors.
-
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-    """
     seq_len, dim = q.shape[-2], q.shape[-1]
     cos = cos[:seq_len].view(1, 1, seq_len, dim)
     sin = sin[:seq_len].view(1, 1, seq_len, dim)
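The retained body above just slices the caches and broadcasts them over batch and heads; a toy shape check (all sizes invented):

import torch

def rotate_half(x):
    # llama-style rotate_half, included only so the example runs.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

b, h, s, d = 1, 2, 16, 8
q, k = torch.randn(b, h, s, d), torch.randn(b, h, s, d)
cos_cached, sin_cached = torch.randn(32, d), torch.randn(32, d)  # stand-ins for the caches

cos = cos_cached[:s].view(1, 1, s, d)
sin = sin_cached[:s].view(1, 1, s, d)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
print(q_embed.shape, k_embed.shape)  # torch.Size([1, 2, 16, 8]) twice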
@@ -250,7 +223,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
 
-
 
 # Copied from transformers.models.llama.modeling_llama.repeat_kv
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -662,7 +634,7 @@ class GemmoeBlockSparseTop2MLP(nn.Module):
         super().__init__()
         self.ffn_dim = config.intermediate_size
         self.hidden_dim = config.hidden_size
-
+
         self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
         self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
         self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
@@ -683,16 +655,13 @@ class GemmoeSparseMoeBlock(nn.Module):
         self.num_experts = config.num_local_experts
         self.top_k = 2
 
-        # gating
         self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
-
         self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
 
     def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
-        # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)
         routing_weights = F.softmax(router_logits, dim=1)
         topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
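The routing above in isolation, on a toy batch (the expert count and sizes are invented; top-2 comes from the code):

import torch
import torch.nn.functional as F

num_experts, top_k, hidden_dim = 8, 2, 16
tokens = torch.randn(4 * 3, hidden_dim)     # (batch * sequence_length, hidden_dim)
gate = torch.nn.Linear(hidden_dim, num_experts, bias=False)

router_logits = gate(tokens)                # (tokens, num_experts)
routing_weights = F.softmax(router_logits, dim=1)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
print(topk_weight.shape, topk_idx.shape)    # torch.Size([12, 2]) twice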
@@ -700,17 +669,18 @@ class GemmoeSparseMoeBlock(nn.Module):
 
         hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
 
-
-
-        flat_topk_idx = topk_idx.view(-1)
+        expert_outputs = []
         for i in range(self.num_experts):
-
-            expert_output =
-
-
-        y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
+            expert_input = hidden_states[topk_idx[:, i]]
+            expert_output = self.experts[i](expert_input)
+            expert_outputs.append(expert_output)
 
-
+        expert_outputs = torch.stack(expert_outputs, dim=1)
+        expert_outputs = expert_outputs.view(batch_size * sequence_length, self.top_k, -1)
+
+        final_hidden_states = torch.einsum("bke,bkd->bed", topk_weight, expert_outputs)
+        final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
+
         return final_hidden_states.to(hidden_states.dtype), router_logits.to(hidden_states.dtype)
 
 
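One common way to weight and merge the two selected experts' outputs per token, shown only to illustrate the intent of the rewritten block, not as a restatement of the committed loop:

import torch

tokens, top_k, hidden_dim = 12, 2, 16
topk_weight = torch.rand(tokens, top_k)                  # per-token routing weights
expert_outputs = torch.randn(tokens, top_k, hidden_dim)  # one output per selected expert

# Scale each expert output by its routing weight and sum over the top-k axis.
combined = (expert_outputs * topk_weight.unsqueeze(-1)).sum(dim=1)
print(combined.shape)  # torch.Size([12, 16])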
@@ -719,12 +689,10 @@ class GemmoeDecoderLayer(nn.Module):
     def __init__(self, config: GemmoeConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
-
+
         self.self_attn = GEMMOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
-
         self.block_sparse_moe = GemmoeSparseMoeBlock(config)
         self.input_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -737,32 +705,9 @@ class GemmoeDecoderLayer(nn.Module):
         use_cache: Optional[bool] = False,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        if "padding_mask" in kwargs:
-            warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-            )
-        """
-        Args:
-            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
-            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
-                `(batch, sequence_length)` where padding elements are indicated by 0.
-            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            output_router_logits (`bool`, *optional*):
-                Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
-                should not be returned during inference.
-            use_cache (`bool`, *optional*):
-                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
-                (see `past_key_values`).
-        """
-
         residual = hidden_states
-
         hidden_states = self.input_layernorm(hidden_states)
 
-        # Self Attention
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
@@ -772,21 +717,17 @@ class GemmoeDecoderLayer(nn.Module):
             use_cache=use_cache,
         )
         hidden_states = residual + hidden_states
-
-        # Fully Connected
+
         residual = hidden_states
-        hidden_states = self.
+        hidden_states = self.input_layernorm(hidden_states)
         hidden_states, router_logits = self.block_sparse_moe(hidden_states)
         hidden_states = residual + hidden_states
 
         outputs = (hidden_states,)
-
         if output_attentions:
             outputs += (self_attn_weights,)
-
         if use_cache:
             outputs += (present_key_value,)
-
         if output_router_logits:
             outputs += (router_logits,)
 
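After this hunk the layer follows the usual pre-norm residual pattern; a skeletal sketch with placeholder sub-modules (the real layer uses its attention class and GemmoeSparseMoeBlock, and the diff above now reuses input_layernorm before the MoE):

import torch
import torch.nn as nn

class ToyDecoderLayer(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)  # placeholder for GemmoeRMSNorm
        self.attn = nn.Identity()                         # placeholder for self-attention
        self.moe = nn.Identity()                          # placeholder for the sparse MoE block

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        hidden_states = residual + self.attn(self.input_layernorm(hidden_states))

        residual = hidden_states
        hidden_states = residual + self.moe(self.input_layernorm(hidden_states))
        return hidden_states

print(ToyDecoderLayer(16)(torch.randn(2, 4, 16)).shape)  # torch.Size([2, 4, 16])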
@@ -1009,14 +950,6 @@ class GemmoeModel(GemmoePreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        # Scale embeddings
-        # Fix for precision issue when casting to bfloat16
-        hidden_size_sqrt = math.sqrt(self.config.hidden_size)
-        if inputs_embeds.dtype == torch.bfloat16:
-            pass
-
-        hidden_states = inputs_embeds * hidden_size_sqrt
-
         past_seen_tokens = 0
         if use_cache: # kept for BC (cache positions)
             if not isinstance(past_key_values, StaticCache):
@@ -1036,8 +969,12 @@ class GemmoeModel(GemmoePreTrainedModel):
         # embed positions
         hidden_states = inputs_embeds
 
-        #
-
+        # Scale embeddings
+        hidden_size_sqrt = math.sqrt(self.config.hidden_size)
+        if inputs_embeds.dtype == torch.bfloat16:
+            hidden_states = inputs_embeds * hidden_size_sqrt
+        else:
+            hidden_states = inputs_embeds * hidden_size_sqrt
 
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
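The scaling added here multiplies the embeddings by the square root of the hidden size (both dtype branches apply the same expression); a minimal sketch with an invented hidden_size:

import math
import torch

hidden_size = 16                                   # invented for the example
inputs_embeds = torch.randn(2, 4, hidden_size, dtype=torch.bfloat16)

hidden_size_sqrt = math.sqrt(hidden_size)
hidden_states = inputs_embeds * hidden_size_sqrt
print(hidden_states.dtype, hidden_states.shape)    # torch.bfloat16 torch.Size([2, 4, 16])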