Crystalcareai committed
Commit 892da81 · verified · 1 Parent(s): a0d5586

Update modeling_gemmoe.py

Files changed (1):
  modeling_gemmoe.py  +27 -90
modeling_gemmoe.py CHANGED
@@ -167,10 +167,8 @@ class GemmoeRMSNorm(nn.Module):
         self.weight = nn.Parameter(torch.zeros(dim))
 
     def _norm(self, x):
-        # Ensure the entire normalization is done in float32
-        x_float = x.float()  # upcast to float32
-        mean = x_float.pow(2).mean(-1, keepdim=True)
-        normed_x = x_float * torch.rsqrt(mean + self.eps)
+        x_float = x.float()
+        normed_x = x_float * torch.rsqrt(x_float.pow(2).mean(-1, keepdim=True) + self.eps)
         return normed_x
 
     def forward(self, x):
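
For reference, the new `_norm` keeps the float32 upcast but folds the mean into a single expression. A minimal standalone sketch of the same RMS normalization (the `rms_norm` helper is illustrative; the `(1 + weight)` rescale is an assumption based on the zero-initialized weight above, since the module's `forward` is not shown in this diff):

    import torch

    def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Upcast to float32 and divide by the root mean square over the last dimension.
        x_float = x.float()
        normed = x_float * torch.rsqrt(x_float.pow(2).mean(-1, keepdim=True) + eps)
        # Assumed Gemma-style scaling: weight starts at zero, so the effective scale is (1 + weight).
        return (normed * (1.0 + weight.float())).type_as(x)

    x = torch.randn(2, 4, 8, dtype=torch.bfloat16)
    w = torch.zeros(8)
    print(rms_norm(x, w).dtype)  # torch.bfloat16
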
@@ -191,16 +189,14 @@ class GemmoeRotaryEmbedding(nn.Module):
 
     def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.max_seq_len_cached = seq_len
-        freq_exponents = (2.0 / self.dim) * (
-            torch.arange(self.dim // 2, dtype=torch.int64, device="cpu").float()
-        )
+        freq_exponents = (2.0 / self.dim) * (torch.arange(self.dim // 2, dtype=torch.float32, device="cpu").float())
         timescale = self.base ** freq_exponents
-        positions = torch.arange(self.max_seq_len_cached, device="cpu", dtype=torch.int64).float()
+        positions = torch.arange(self.max_seq_len_cached, device="cpu", dtype=torch.float32).float()
         radians_new = positions[..., None] / timescale[None, None, :]
         radians_new = radians_new.squeeze(0)
         emb = torch.cat((radians_new, radians_new), dim=-1)
-        cos = emb.cos().to(device=device, non_blocking=True)
-        sin = emb.sin().to(device=device, non_blocking=True)
+        cos = emb.cos().to(device=device, dtype=dtype, non_blocking=True)
+        sin = emb.sin().to(device=device, dtype=dtype, non_blocking=True)
         self.register_buffer("cos_cached", cos, persistent=False)
         self.register_buffer("sin_cached", sin, persistent=False)
 
@@ -209,10 +205,7 @@ class GemmoeRotaryEmbedding(nn.Module):
         seq_len = x.size(2)
         if seq_len > self.max_seq_len_cached:
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-        return (
-            self.cos_cached[:seq_len],
-            self.sin_cached[:seq_len],
-        )
+        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
 
 # Copied from transformers.models.llama.modeling_llama.rotate_half
 def rotate_half(x):
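
For reference, the cache is built in float32 on CPU and only the resulting cos/sin tables are moved to the target device and dtype. A self-contained sketch of the same computation (the `build_rotary_cache` helper is illustrative, not part of the file):

    import torch

    def build_rotary_cache(dim, max_positions, base=10000.0, device="cpu", dtype=torch.float32):
        # Inverse frequencies follow base ** (2i / dim) for i in [0, dim // 2).
        freq_exponents = (2.0 / dim) * torch.arange(dim // 2, dtype=torch.float32)
        timescale = base ** freq_exponents
        positions = torch.arange(max_positions, dtype=torch.float32)
        radians = positions[:, None] / timescale[None, :]   # (max_positions, dim // 2)
        emb = torch.cat((radians, radians), dim=-1)          # (max_positions, dim)
        return emb.cos().to(device=device, dtype=dtype), emb.sin().to(device=device, dtype=dtype)

    cos, sin = build_rotary_cache(dim=8, max_positions=16)
    print(cos.shape, sin.shape)  # torch.Size([16, 8]) torch.Size([16, 8])
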
@@ -222,27 +215,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors.
-
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-    """
     seq_len, dim = q.shape[-2], q.shape[-1]
     cos = cos[:seq_len].view(1, 1, seq_len, dim)
     sin = sin[:seq_len].view(1, 1, seq_len, dim)
@@ -250,7 +223,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
 
-
 
 # Copied from transformers.models.llama.modeling_llama.repeat_kv
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
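
For reference, a self-contained usage sketch of the rotation applied above, with toy shapes (`rotate_half` is restated locally so the snippet runs on its own and may differ cosmetically from the file's version):

    import torch

    def rotate_half(x):
        # (x1, x2) -> (-x2, x1) along the last dimension.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    batch, heads, seq_len, head_dim = 1, 2, 16, 8
    q = torch.randn(batch, heads, seq_len, head_dim)
    k = torch.randn(batch, heads, seq_len, head_dim)

    # Any cos/sin tables of shape (seq_len, head_dim) broadcast the same way as in the diff.
    positions = torch.arange(seq_len, dtype=torch.float32)
    timescale = 10000.0 ** ((2.0 / head_dim) * torch.arange(head_dim // 2, dtype=torch.float32))
    radians = positions[:, None] / timescale[None, :]
    emb = torch.cat((radians, radians), dim=-1)
    cos = emb.cos().view(1, 1, seq_len, head_dim)
    sin = emb.sin().view(1, 1, seq_len, head_dim)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    print(q_embed.shape, k_embed.shape)  # torch.Size([1, 2, 16, 8]) torch.Size([1, 2, 16, 8])
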
@@ -662,7 +634,7 @@ class GemmoeBlockSparseTop2MLP(nn.Module):
         super().__init__()
         self.ffn_dim = config.intermediate_size
         self.hidden_dim = config.hidden_size
-
+
         self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
         self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
         self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
@@ -683,16 +655,13 @@ class GemmoeSparseMoeBlock(nn.Module):
         self.num_experts = config.num_local_experts
         self.top_k = 2
 
-        # gating
         self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
-
         self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
 
     def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
-        # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)
         routing_weights = F.softmax(router_logits, dim=1)
         topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
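
The router takes a softmax over the expert logits and keeps the two largest weights per token. A self-contained reference sketch of that routing pattern (the `top2_route` helper and the per-expert masking loop are illustrative; the commit's own combination step, shown in the next hunk, batches expert outputs and mixes them with an einsum instead):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def top2_route(hidden_states, gate, experts, top_k=2):
        # hidden_states: (num_tokens, hidden_dim)
        router_logits = gate(hidden_states)                   # (num_tokens, num_experts)
        routing_weights = F.softmax(router_logits, dim=-1)
        topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1)

        out = torch.zeros_like(hidden_states)
        for expert_id, expert in enumerate(experts):
            for slot in range(top_k):
                mask = topk_idx[:, slot] == expert_id         # tokens whose slot-th pick is this expert
                if mask.any():
                    out[mask] += topk_weight[mask, slot, None] * expert(hidden_states[mask])
        return out, router_logits

    gate = nn.Linear(8, 4, bias=False)
    experts = nn.ModuleList([nn.Linear(8, 8, bias=False) for _ in range(4)])
    out, logits = top2_route(torch.randn(6, 8), gate, experts)
    print(out.shape, logits.shape)  # torch.Size([6, 8]) torch.Size([6, 4])
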
@@ -700,17 +669,18 @@ class GemmoeSparseMoeBlock(nn.Module):
 
         hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
 
-        y = torch.empty_like(hidden_states)
-
-        flat_topk_idx = topk_idx.view(-1)
+        expert_outputs = []
         for i in range(self.num_experts):
-            expert = self.experts[i]
-            expert_output = expert(hidden_states[flat_topk_idx == i])
-            y[flat_topk_idx == i] = expert_output
-
-        y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
+            expert_input = hidden_states[topk_idx[:, i]]
+            expert_output = self.experts[i](expert_input)
+            expert_outputs.append(expert_output)
 
-        final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
+        expert_outputs = torch.stack(expert_outputs, dim=1)
+        expert_outputs = expert_outputs.view(batch_size * sequence_length, self.top_k, -1)
+
+        final_hidden_states = torch.einsum("bke,bkd->bed", topk_weight, expert_outputs)
+        final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
+
         return final_hidden_states.to(hidden_states.dtype), router_logits.to(hidden_states.dtype)
 
 
@@ -719,12 +689,10 @@ class GemmoeDecoderLayer(nn.Module):
     def __init__(self, config: GemmoeConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
-
+
         self.self_attn = GEMMOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
-
         self.block_sparse_moe = GemmoeSparseMoeBlock(config)
         self.input_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -737,32 +705,9 @@ class GemmoeDecoderLayer(nn.Module):
         use_cache: Optional[bool] = False,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        if "padding_mask" in kwargs:
-            warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-            )
-        """
-        Args:
-            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
-            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
-                `(batch, sequence_length)` where padding elements are indicated by 0.
-            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            output_router_logits (`bool`, *optional*):
-                Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
-                should not be returned during inference.
-            use_cache (`bool`, *optional*):
-                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
-                (see `past_key_values`).
-        """
-
         residual = hidden_states
-
         hidden_states = self.input_layernorm(hidden_states)
 
-        # Self Attention
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
@@ -772,21 +717,17 @@ class GemmoeDecoderLayer(nn.Module):
             use_cache=use_cache,
         )
         hidden_states = residual + hidden_states
-
-        # Fully Connected
+
         residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, router_logits = self.block_sparse_moe(hidden_states)
         hidden_states = residual + hidden_states
 
         outputs = (hidden_states,)
-
         if output_attentions:
             outputs += (self_attn_weights,)
-
         if use_cache:
             outputs += (present_key_value,)
-
         if output_router_logits:
             outputs += (router_logits,)
 
@@ -1009,14 +950,6 @@ class GemmoeModel(GemmoePreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        # Scale embeddings
-        # Fix for precision issue when casting to bfloat16
-        hidden_size_sqrt = math.sqrt(self.config.hidden_size)
-        if inputs_embeds.dtype == torch.bfloat16:
-            pass
-
-        hidden_states = inputs_embeds * hidden_size_sqrt
-
         past_seen_tokens = 0
         if use_cache:  # kept for BC (cache positions)
             if not isinstance(past_key_values, StaticCache):
@@ -1036,8 +969,12 @@ class GemmoeModel(GemmoePreTrainedModel):
         # embed positions
         hidden_states = inputs_embeds
 
-        # normalized
-        hidden_states = hidden_states * (self.config.hidden_size**0.5)
+        # Scale embeddings
+        hidden_size_sqrt = math.sqrt(self.config.hidden_size)
+        if inputs_embeds.dtype == torch.bfloat16:
+            hidden_states = inputs_embeds * hidden_size_sqrt
+        else:
+            hidden_states = inputs_embeds * hidden_size_sqrt
 
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
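
The √hidden_size input scaling now happens here, after `hidden_states = inputs_embeds`; after this change both dtype branches apply the same multiplication. A minimal sketch of the scaling itself:

    import math
    import torch

    hidden_size = 2048
    inputs_embeds = torch.randn(1, 4, hidden_size, dtype=torch.bfloat16)

    # Gemma-style input scaling: multiply the token embeddings by sqrt(hidden_size).
    hidden_states = inputs_embeds * math.sqrt(hidden_size)
    print(hidden_states.shape, hidden_states.dtype)  # torch.Size([1, 4, 2048]) torch.bfloat16
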
 