Crystalcareai commited on
Commit
ec7309a
·
verified ·
1 Parent(s): 54b48ed

Update modeling_gemmoe.py

Browse files
Files changed (1) hide show
  1. modeling_gemmoe.py +753 -603
modeling_gemmoe.py CHANGED
@@ -26,11 +26,14 @@ from torch import nn
26
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
 
28
  from transformers.activations import ACT2FN
29
- from transformers.cache_utils import Cache, DynamicCache, StaticCache
30
  from transformers.modeling_attn_mask_utils import (
 
 
31
  _prepare_4d_causal_attention_mask,
 
32
  )
33
- from transformers.modeling_outputs import SequenceClassifierOutputWithPast, MoeModelOutputWithPast, MoeCausalLMOutputWithPast
34
  from transformers.modeling_utils import PreTrainedModel
35
  from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
36
  from transformers.utils import (
@@ -60,7 +63,6 @@ if is_torch_fx_available():
60
 
61
  _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
62
 
63
-
64
  logger = logging.get_logger(__name__)
65
 
66
  _CONFIG_FOR_DOC = "GemmoeConfig"
@@ -156,55 +158,121 @@ def _get_unpad_data(attention_mask):
156
  max_seqlen_in_batch,
157
  )
158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
 
161
  class GemmoeRMSNorm(nn.Module):
162
- def __init__(self, dim: int, eps: float = 1e-6):
 
 
 
163
  super().__init__()
164
- self.eps = eps
165
- self.weight = nn.Parameter(torch.zeros(dim))
166
-
167
- def _norm(self, x):
168
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
169
 
170
- def forward(self, x):
171
- output = self._norm(x.float()).type_as(x)
172
- return output * (self.weight + 1)
 
 
 
173
 
174
  ALL_LAYERNORM_LAYERS.append(GemmoeRMSNorm)
175
 
176
  class GemmoeRotaryEmbedding(nn.Module):
177
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
178
  super().__init__()
 
179
  self.dim = dim
180
  self.max_position_embeddings = max_position_embeddings
181
  self.base = base
182
- self._set_cos_sin_cache(seq_len=max_position_embeddings, device=device, dtype=torch.get_default_dtype())
 
 
 
 
 
 
 
 
183
 
184
  def _set_cos_sin_cache(self, seq_len, device, dtype):
185
  self.max_seq_len_cached = seq_len
186
- freq_exponents = (2.0 / self.dim) * (
187
- torch.arange(self.dim // 2, dtype=torch.int64, device="cpu").float()
188
- )
189
- timescale = self.base ** freq_exponents
190
- positions = torch.arange(self.max_seq_len_cached, device="cpu", dtype=torch.int64).float()
191
- radians_new = positions[..., None] / timescale[None, None, :]
192
- radians_new = radians_new.squeeze(0)
193
- emb = torch.cat((radians_new, radians_new), dim=-1)
194
- cos = emb.cos().to(device=device, non_blocking=True)
195
- sin = emb.sin().to(device=device, non_blocking=True)
196
- self.register_buffer("cos_cached", cos, persistent=False)
197
- self.register_buffer("sin_cached", sin, persistent=False)
198
-
199
- def forward(self, x, position_ids=None, seq_len=None):
200
- if seq_len is None:
201
- seq_len = x.size(2)
202
- if seq_len > self.max_seq_len_cached:
203
  self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
 
204
  return (
205
- self.cos_cached[:seq_len],
206
- self.sin_cached[:seq_len],
207
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
  def rotate_half(x):
210
  """Rotates half the hidden dims of the input."""
@@ -212,16 +280,199 @@ def rotate_half(x):
212
  x2 = x[..., x.shape[-1] // 2 :]
213
  return torch.cat((-x2, x1), dim=-1)
214
 
215
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
216
- """Applies Rotary Position Embedding to the query and key tensors."""
217
- seq_len, dim = q.shape[-2], q.shape[-1]
218
- cos = cos[:seq_len].view(1, 1, seq_len, dim)
219
- sin = sin[:seq_len].view(1, 1, seq_len, dim)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  q_embed = (q * cos) + (rotate_half(q) * sin)
221
  k_embed = (k * cos) + (rotate_half(k) * sin)
222
  return q_embed, k_embed
223
 
224
- def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  """
226
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
227
  num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
@@ -231,15 +482,10 @@ def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
231
  return hidden_states
232
  hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
233
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
234
 
235
  class GemmoeAttention(nn.Module):
236
- """
237
- Multi-headed attention module for Gemmoe model.
238
-
239
- Args:
240
- config (GemmoeConfig): The configuration object for the Gemmoe model.
241
- layer_idx (Optional[int]): The index of the layer. Default is None.
242
- """
243
 
244
  def __init__(self, config: GemmoeConfig, layer_idx: Optional[int] = None):
245
  super().__init__()
@@ -247,34 +493,62 @@ class GemmoeAttention(nn.Module):
247
  self.layer_idx = layer_idx
248
  if layer_idx is None:
249
  logger.warning_once(
250
- f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
251
- "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
252
  "when creating this class."
253
  )
 
254
  self.attention_dropout = config.attention_dropout
255
  self.hidden_size = config.hidden_size
256
  self.num_heads = config.num_attention_heads
257
- self.head_dim = config.head_dim
258
  self.num_key_value_heads = config.num_key_value_heads
259
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
260
  self.max_position_embeddings = config.max_position_embeddings
261
  self.rope_theta = config.rope_theta
262
  self.is_causal = True
263
 
264
- if self.hidden_size % self.num_heads != 0:
265
  raise ValueError(
266
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
267
  f" and `num_heads`: {self.num_heads})."
268
  )
 
269
  self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
270
  self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
271
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
272
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
273
- self.rotary_emb = GemmoeRotaryEmbedding(
274
- self.head_dim,
275
- max_position_embeddings=self.max_position_embeddings,
276
- base=self.rope_theta,
277
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
 
279
  def forward(
280
  self,
@@ -284,64 +558,78 @@ class GemmoeAttention(nn.Module):
284
  past_key_value: Optional[Cache] = None,
285
  output_attentions: bool = False,
286
  use_cache: bool = False,
287
- cache_position: Optional[torch.LongTensor] = None,
288
  **kwargs,
289
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
290
- """
291
- Forward pass of the attention module.
292
-
293
- Args:
294
- hidden_states (torch.Tensor): The input hidden states.
295
- attention_mask (Optional[torch.Tensor]): The attention mask. Default is None.
296
- position_ids (Optional[torch.LongTensor]): The position IDs. Default is None.
297
- past_key_value (Optional[Cache]): The past key-value cache. Default is None.
298
- output_attentions (bool): Whether to output the attention weights. Default is False.
299
- use_cache (bool): Whether to use caching. Default is False.
300
- cache_position (Optional[torch.LongTensor]): The cache position. Default is None.
301
- **kwargs: Additional keyword arguments.
302
 
303
- Returns:
304
- Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
305
- - The output hidden states.
306
- - The attention weights (if `output_attentions=True`).
307
- - The past key-value cache (if `use_cache=True`).
308
- """
309
  bsz, q_len, _ = hidden_states.size()
310
 
311
- query_states = self.q_proj(hidden_states)
312
- key_states = self.k_proj(hidden_states)
313
- value_states = self.v_proj(hidden_states)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
 
315
  query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
316
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
317
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
318
 
319
- past_key_value = getattr(self, "past_key_value", past_key_value)
320
-
321
- cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
322
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
 
 
 
 
 
 
 
323
 
324
  if past_key_value is not None:
325
- # sin and cos are specific to RoPE models; position_ids needed for the static cache
326
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
327
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
328
 
329
- key_states = self.repeat_kv(key_states, self.num_key_value_groups)
330
- value_states = self.repeat_kv(value_states, self.num_key_value_groups)
331
 
332
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
333
 
334
- if attention_mask is not None: # no matter the length, we just slice it
335
- if cache_position is not None:
336
- causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
337
- else:
338
- causal_mask = attention_mask
339
- attn_weights = attn_weights + causal_mask
 
 
 
 
 
 
340
 
341
  # upcast attention to fp32
342
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
343
  attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
344
-
345
  attn_output = torch.matmul(attn_weights, value_states)
346
 
347
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -351,9 +639,15 @@ class GemmoeAttention(nn.Module):
351
  )
352
 
353
  attn_output = attn_output.transpose(1, 2).contiguous()
354
- attn_output = attn_output.view(bsz, q_len, -1)
355
 
356
- attn_output = self.o_proj(attn_output)
 
 
 
 
 
 
 
357
 
358
  if not output_attentions:
359
  attn_weights = None
@@ -366,9 +660,13 @@ class GemmoeFlashAttention2(GemmoeAttention):
366
  untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
367
  flash attention and deal with padding tokens in case the input contains any of them.
368
  """
 
369
  def __init__(self, *args, **kwargs):
370
  super().__init__(*args, **kwargs)
371
- # TODO: Remove this attribute once Flash Attention for RoCm is bumped to 2.1.
 
 
 
372
  self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
373
 
374
  def forward(
@@ -379,9 +677,17 @@ class GemmoeFlashAttention2(GemmoeAttention):
379
  past_key_value: Optional[Cache] = None,
380
  output_attentions: bool = False,
381
  use_cache: bool = False,
382
- cache_position: Optional[torch.LongTensor] = None,
383
  **kwargs,
384
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 
 
 
 
 
 
 
 
 
385
  output_attentions = False
386
 
387
  bsz, q_len, _ = hidden_states.size()
@@ -397,13 +703,14 @@ class GemmoeFlashAttention2(GemmoeAttention):
397
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
398
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
399
 
400
- cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
401
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
 
 
 
402
 
403
- past_key_value = getattr(self, "past_key_value", past_key_value)
404
  if past_key_value is not None:
405
- # sin and cos are specific to RoPE models; position_ids needed for the static cache
406
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
407
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
408
 
409
  # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
@@ -419,13 +726,14 @@ class GemmoeFlashAttention2(GemmoeAttention):
419
  # cast them back in the correct dtype just to be sure everything works as expected.
420
  # This might slowdown training & inference so it is recommended to not cast the LayerNorms
421
  # in fp32. (GemmoeRMSNorm handles it correctly)
 
422
  input_dtype = query_states.dtype
423
  if input_dtype == torch.float32:
424
- if torch.is_autocast_enabled():
425
- target_dtype = torch.get_autocast_gpu_dtype()
426
  # Handle the case where the model is quantized
427
- elif hasattr(self.config, "_pre_quantization_dtype"):
428
  target_dtype = self.config._pre_quantization_dtype
 
 
429
  else:
430
  target_dtype = self.q_proj.weight.dtype
431
 
@@ -434,6 +742,7 @@ class GemmoeFlashAttention2(GemmoeAttention):
434
  f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
435
  f" {target_dtype}."
436
  )
 
437
  query_states = query_states.to(target_dtype)
438
  key_states = key_states.to(target_dtype)
439
  value_states = value_states.to(target_dtype)
@@ -442,7 +751,7 @@ class GemmoeFlashAttention2(GemmoeAttention):
442
  query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
443
  )
444
 
445
- attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
446
  attn_output = self.o_proj(attn_output)
447
 
448
  if not output_attentions:
@@ -484,6 +793,7 @@ class GemmoeFlashAttention2(GemmoeAttention):
484
  query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
485
  query_states, key_states, value_states, attention_mask, query_length
486
  )
 
487
  cu_seqlens_q, cu_seqlens_k = cu_seq_lens
488
  max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
489
 
@@ -499,6 +809,7 @@ class GemmoeFlashAttention2(GemmoeAttention):
499
  softmax_scale=softmax_scale,
500
  causal=causal,
501
  )
 
502
  attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
503
  else:
504
  attn_output = flash_attn_func(
@@ -509,15 +820,14 @@ class GemmoeFlashAttention2(GemmoeAttention):
509
 
510
  def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
511
  indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
512
-
513
  batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
 
514
  key_layer = index_first_axis(
515
  key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
516
  )
517
  value_layer = index_first_axis(
518
  value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
519
  )
520
-
521
  if query_length == kv_seq_len:
522
  query_layer = index_first_axis(
523
  query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
@@ -549,21 +859,11 @@ class GemmoeFlashAttention2(GemmoeAttention):
549
  class GemmoeSdpaAttention(GemmoeAttention):
550
  """
551
  Gemmoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
552
- GemmoeAttention as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
553
  SDPA API.
554
  """
555
 
556
- def repeat_kv(self, x, n_rep):
557
- """
558
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
559
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
560
- """
561
- batch, num_key_value_heads, slen, head_dim = x.shape
562
- if n_rep == 1:
563
- return x
564
- x = x[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
565
- return x.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
566
-
567
  def forward(
568
  self,
569
  hidden_states: torch.Tensor,
@@ -572,15 +872,13 @@ class GemmoeSdpaAttention(GemmoeAttention):
572
  past_key_value: Optional[Cache] = None,
573
  output_attentions: bool = False,
574
  use_cache: bool = False,
575
- cache_position: Optional[torch.LongTensor] = None,
576
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
577
  if output_attentions:
578
  # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
579
- # logger.warning_once(
580
- "GemmoeModel is using GemmoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
581
- 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
582
- # )
583
-
584
  return super().forward(
585
  hidden_states=hidden_states,
586
  attention_mask=attention_mask,
@@ -588,9 +886,8 @@ class GemmoeSdpaAttention(GemmoeAttention):
588
  past_key_value=past_key_value,
589
  output_attentions=output_attentions,
590
  use_cache=use_cache,
591
- cache_position=cache_position,
592
  )
593
-
594
  bsz, q_len, _ = hidden_states.size()
595
 
596
  query_states = self.q_proj(hidden_states)
@@ -601,48 +898,46 @@ class GemmoeSdpaAttention(GemmoeAttention):
601
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
602
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
603
 
604
- cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
605
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
 
 
 
 
606
 
607
- past_key_value = getattr(self, "past_key_value", past_key_value)
608
  if past_key_value is not None:
609
- # sin and cos are specific to RoPE models; position_ids needed for the static cache
610
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
611
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
612
 
613
- key_states = self.repeat_kv(key_states, self.num_key_value_groups)
614
- value_states = self.repeat_kv(value_states, self.num_key_value_groups)
615
 
616
- causal_mask = attention_mask
617
- if attention_mask is not None and cache_position is not None:
618
- causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
619
-
620
- # Ensure query, key, and value states have the same dtype
621
- common_dtype = query_states.dtype
622
- key_states = key_states.to(dtype=common_dtype)
623
- value_states = value_states.to(dtype=common_dtype)
624
 
625
  # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
626
  # Reference: https://github.com/pytorch/pytorch/issues/112577.
627
- if query_states.device.type == "cuda" and causal_mask is not None:
628
  query_states = query_states.contiguous()
629
  key_states = key_states.contiguous()
630
  value_states = value_states.contiguous()
631
 
632
- # Cast causal_mask to the same dtype as query_states
633
- if causal_mask is not None:
634
- causal_mask = causal_mask.to(dtype=query_states.dtype)
635
-
636
  attn_output = torch.nn.functional.scaled_dot_product_attention(
637
  query_states,
638
  key_states,
639
  value_states,
640
- attn_mask=causal_mask,
641
  dropout_p=self.attention_dropout if self.training else 0.0,
 
 
642
  )
643
 
644
  attn_output = attn_output.transpose(1, 2).contiguous()
645
- attn_output = attn_output.view(bsz, q_len, -1)
 
646
  attn_output = self.o_proj(attn_output)
647
 
648
  return attn_output, None, past_key_value
@@ -653,74 +948,17 @@ GEMMOE_ATTENTION_CLASSES = {
653
  "sdpa": GemmoeSdpaAttention,
654
  }
655
 
656
- class GemmoeBlockSparseTop2MLP(nn.Module):
657
- def __init__(self, config: GemmoeConfig):
658
- super().__init__()
659
- self.ffn_dim = config.intermediate_size
660
- self.hidden_dim = config.hidden_size
661
-
662
- self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
663
- self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
664
- self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
665
-
666
- self.act_fn = approx_gelu
667
-
668
- def forward(self, hidden_states):
669
- current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
670
- current_hidden_states = self.w2(current_hidden_states)
671
- return current_hidden_states
672
-
673
-
674
- class GemmoeSparseMoeBlock(nn.Module):
675
- def __init__(self, config):
676
- super().__init__()
677
- self.hidden_dim = config.hidden_size
678
- self.ffn_dim = config.intermediate_size
679
- self.num_experts = config.num_local_experts
680
- self.top_k = 2
681
-
682
- # gating
683
- self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
684
-
685
- self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
686
-
687
- def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
688
- batch_size, sequence_length, hidden_dim = hidden_states.shape
689
- hidden_states = hidden_states.view(-1, hidden_dim)
690
-
691
- # router_logits: (batch * sequence_length, n_experts)
692
- router_logits = self.gate(hidden_states)
693
- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
694
- topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
695
- topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
696
-
697
- # we cast back to the input dtype
698
- topk_weight = topk_weight.to(hidden_states.dtype)
699
 
700
- hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
701
-
702
- y = torch.empty_like(hidden_states)
703
-
704
- flat_topk_idx = topk_idx.view(-1)
705
- for i in range(self.num_experts):
706
- expert = self.experts[i]
707
- expert_output = expert(hidden_states[flat_topk_idx == i])
708
- y[flat_topk_idx == i] = expert_output.to(y.dtype) # Cast expert_output to the same dtype as y
709
-
710
- y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
711
-
712
- final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
713
- return final_hidden_states, router_logits
714
-
715
-
716
  class GemmoeDecoderLayer(nn.Module):
717
  def __init__(self, config: GemmoeConfig, layer_idx: int):
718
  super().__init__()
719
  self.hidden_size = config.hidden_size
720
 
721
- self.self_attn = GEMMOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
722
 
723
- self.block_sparse_moe = GemmoeSparseMoeBlock(config)
 
 
724
  self.input_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
725
  self.post_attention_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
726
 
@@ -731,9 +969,7 @@ class GemmoeDecoderLayer(nn.Module):
731
  position_ids: Optional[torch.LongTensor] = None,
732
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
733
  output_attentions: Optional[bool] = False,
734
- output_router_logits: Optional[bool] = False,
735
  use_cache: Optional[bool] = False,
736
- cache_position: Optional[torch.LongTensor] = None,
737
  **kwargs,
738
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
739
  """
@@ -749,16 +985,13 @@ class GemmoeDecoderLayer(nn.Module):
749
  If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
750
  (see `past_key_values`).
751
  past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
752
- output_router_logits (`bool`, *optional*):
753
- Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
754
- should not be returned during inference.
755
  """
756
  if "padding_mask" in kwargs:
757
  warnings.warn(
758
  "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
759
  )
760
-
761
  residual = hidden_states
 
762
  hidden_states = self.input_layernorm(hidden_states)
763
 
764
  # Self Attention
@@ -769,7 +1002,6 @@ class GemmoeDecoderLayer(nn.Module):
769
  past_key_value=past_key_value,
770
  output_attentions=output_attentions,
771
  use_cache=use_cache,
772
- cache_position=cache_position,
773
  **kwargs,
774
  )
775
  hidden_states = residual + hidden_states
@@ -777,9 +1009,8 @@ class GemmoeDecoderLayer(nn.Module):
777
  # Fully Connected
778
  residual = hidden_states
779
  hidden_states = self.post_attention_layernorm(hidden_states)
780
- hidden_states, router_logits = self.block_sparse_moe(hidden_states)
781
  hidden_states = residual + hidden_states
782
-
783
 
784
  outputs = (hidden_states,)
785
 
@@ -789,15 +1020,23 @@ class GemmoeDecoderLayer(nn.Module):
789
  if use_cache:
790
  outputs += (present_key_value,)
791
 
792
- if output_router_logits:
793
- outputs += (router_logits,)
794
-
795
  return outputs
796
 
 
797
  GEMMOE_START_DOCSTRING = r"""
798
- This model inherits from [PreTrainedModel]. Check the superclass documentation for the generic methods the
799
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
800
- etc.)
 
 
 
 
 
 
 
 
 
 
801
  """
802
 
803
  @add_start_docstrings(
@@ -806,52 +1045,94 @@ GEMMOE_START_DOCSTRING,
806
  )
807
 
808
  class GemmoePreTrainedModel(PreTrainedModel):
809
- config_class = GemmoeConfig
810
- base_model_prefix = "model"
811
- supports_gradient_checkpointing = True
812
- _keep_in_fp32_modules = ["inv_freq", "rotary_emb", "cos_cached", "sin_cached"]
813
- _no_split_modules = ["GemmoeDecoderLayer"]
814
- _skip_keys_device_placement = ["past_key_values", "causal_mask"]
815
- _supports_flash_attn_2 = True
816
- _supports_sdpa = True
817
- _supports_cache_class = True
818
-
819
- def _init_weights(self, module):
820
- std = self.config.initializer_range
821
- if isinstance(module, nn.Linear):
822
- module.weight.data.normal_(mean=0.0, std=std)
823
- if module.bias is not None:
824
- module.bias.data.zero_()
825
- elif isinstance(module, nn.Embedding):
826
- module.weight.data.normal_(mean=0.0, std=std)
827
- if module.padding_idx is not None:
828
- module.weight.data[module.padding_idx].zero_()
829
-
830
- def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
831
- if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
832
- raise ValueError(
833
- "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
834
- "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
835
- )
836
- if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
837
- causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device)
838
- self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
839
-
840
- for layer in self.model.layers:
841
- weights = layer.self_attn.o_proj.weight
842
- layer.self_attn.past_key_value = cache_cls(
843
- self.config, max_batch_size, max_cache_len, device=weights.device, dtype=weights.dtype
844
- )
845
-
846
- def _reset_cache(self):
847
- for layer in self.model.layers:
848
- layer.self_attn.past_key_value = None
849
-
850
- GEMMOE_INPUTS_DOCSTRING = r"""
851
- Args:
852
- input_ids (torch.LongTensor of shape (batch_size, sequence_length)):
853
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
854
- it.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
855
  """
856
 
857
  @add_start_docstrings(
@@ -860,263 +1141,168 @@ GEMMOE_START_DOCSTRING,
860
  )
861
 
862
  class GemmoeModel(GemmoePreTrainedModel):
863
- """
864
- Transformer decoder consisting of config.num_hidden_layers layers. Each layer is a [GemmoeDecoderLayer]Args:
865
- config: GemmoeConfig
866
- """
867
-
868
-
869
- def __init__(self, config: GemmoeConfig):
870
- super().__init__(config)
871
- self.padding_idx = config.pad_token_id
872
- self.vocab_size = config.vocab_size
873
-
874
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
875
- self.layers = nn.ModuleList(
876
- [GemmoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
877
- )
878
-
879
- self.norm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
880
-
881
- self.gradient_checkpointing = False
882
-
883
- # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
884
- # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
885
- causal_mask = torch.full(
886
- (config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
887
- )
888
- self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
889
-
890
- # Initialize weights and apply final processing
891
- self.post_init()
892
-
893
- def get_input_embeddings(self):
894
- return self.embed_tokens
895
-
896
- def set_input_embeddings(self, value):
897
- self.embed_tokens = value
898
-
899
- @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
900
- @replace_return_docstrings(output_type=MoeModelOutputWithPast, config_class=_CONFIG_FOR_DOC)
901
- def forward(
902
- self,
903
- input_ids: torch.LongTensor = None,
904
- attention_mask: Optional[torch.Tensor] = None,
905
- position_ids: Optional[torch.LongTensor] = None,
906
- past_key_values: Optional[List[torch.FloatTensor]] = None,
907
- inputs_embeds: Optional[torch.FloatTensor] = None,
908
- use_cache: Optional[bool] = None,
909
- output_attentions: Optional[bool] = None,
910
- output_hidden_states: Optional[bool] = None,
911
- output_router_logits: Optional[bool] = None,
912
- return_dict: Optional[bool] = None,
913
- cache_position: Optional[torch.LongTensor] = None,
914
- ) -> Union[Tuple, MoeModelOutputWithPast]:
915
- """
916
- Forward pass of the sequence classification model.
917
-
918
- Args:
919
- input_ids: Input token IDs.
920
- attention_mask: Attention mask.
921
- position_ids: Position IDs.
922
- past_key_values: Past key-value pairs.
923
- inputs_embeds: Input embeddings.
924
- labels: Labels for sequence classification.
925
- use_cache: Whether to use cache.
926
- output_attentions: Whether to output attentions.
927
- output_hidden_states: Whether to output hidden states.
928
- return_dict: Whether to return a dictionary or tuple.
929
-
930
- Returns:
931
- Output of the sequence classification model.
932
- """
933
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
934
- output_hidden_states = (
935
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
936
- )
937
- output_router_logits = (
938
- output_router_logits if output_router_logits is not None else self.config.output_router_logits
939
- )
940
- use_cache = use_cache if use_cache is not None else self.config.use_cache
941
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
942
-
943
- if (input_ids is None) ^ (inputs_embeds is not None):
944
- raise ValueError(
945
- "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
946
- )
947
-
948
- if self.gradient_checkpointing and self.training and use_cache:
949
- logger.warning_once(
950
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
951
- )
952
- use_cache = False
953
-
954
- if inputs_embeds is None:
955
- inputs_embeds = self.embed_tokens(input_ids)
956
-
957
- past_seen_tokens = 0
958
- if use_cache: # kept for BC (cache positions)
959
- if not isinstance(past_key_values, StaticCache):
960
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
961
- past_seen_tokens = past_key_values.get_seq_length()
962
-
963
- if cache_position is None:
964
- cache_position = torch.arange(
965
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
966
- )
967
-
968
- if position_ids is None:
969
- position_ids = cache_position.unsqueeze(0)
970
-
971
- causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
972
-
973
- hidden_states = inputs_embeds
974
-
975
- # Normalize
976
- scale_factor = torch.tensor(math_sqrt(self.config.hidden_size), dtype=hidden_states.dtype)
977
- hidden_states = hidden_states * scale_factor
978
- # Decoder layers
979
- all_hidden_states = () if output_hidden_states else None
980
- all_self_attns = () if output_attentions else None
981
- all_router_logits = () if output_router_logits else None
982
- next_decoder_cache = None
983
-
984
- for decoder_layer in self.layers:
985
- if output_hidden_states:
986
- all_hidden_states += (hidden_states,)
987
-
988
- if self.gradient_checkpointing and self.training:
989
- layer_outputs = self._gradient_checkpointing_func(
990
- decoder_layer.__call__,
991
- hidden_states,
992
- causal_mask,
993
- position_ids,
994
- past_key_values,
995
- output_attentions,
996
- output_router_logits,
997
- use_cache,
998
- cache_position,
999
- )
1000
- else:
1001
- layer_outputs = decoder_layer(
1002
- hidden_states,
1003
- attention_mask=causal_mask,
1004
- position_ids=position_ids,
1005
- past_key_value=past_key_values,
1006
- output_attentions=output_attentions,
1007
- output_router_logits=output_router_logits,
1008
- use_cache=use_cache,
1009
- cache_position=cache_position,
1010
- )
1011
-
1012
- hidden_states = layer_outputs[0]
1013
- if use_cache:
1014
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1015
- if output_attentions:
1016
- all_self_attns += (layer_outputs[1],)
1017
- if output_router_logits:
1018
- all_router_logits += (layer_outputs[-1],)
1019
-
1020
- hidden_states = self.norm(hidden_states)
1021
-
1022
- # Add hidden states from the last decoder layer
1023
- if output_hidden_states:
1024
- all_hidden_states += (hidden_states,)
1025
-
1026
- next_cache = None
1027
- if use_cache:
1028
- next_cache = (
1029
- next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
1030
- )
1031
-
1032
- if not return_dict:
1033
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] if v is not None)
1034
-
1035
- return MoeModelOutputWithPast(
1036
- last_hidden_state=hidden_states,
1037
- past_key_values=next_cache,
1038
- hidden_states=all_hidden_states,
1039
- attentions=all_self_attns,
1040
- router_logits=all_router_logits
1041
- )
1042
-
1043
- def _update_causal_mask(self, attention_mask, input_tensor):
1044
- """
1045
- Update the causal mask based on the attention mask and input tensor.
1046
-
1047
- Args:
1048
- attention_mask (torch.Tensor): The attention mask.
1049
- input_tensor (torch.Tensor): The input tensor.
1050
-
1051
- Returns:
1052
- torch.Tensor: The updated causal mask.
1053
- """
1054
-
1055
- if self.config._attn_implementation == "flash_attention_2":
1056
- if attention_mask is not None and 0.0 in attention_mask:
1057
- return attention_mask
1058
- return None
1059
-
1060
- batch_size, seq_length = input_tensor.shape[:2]
1061
- dtype = input_tensor.dtype
1062
- device = input_tensor.device
1063
-
1064
- # support going beyond cached `max_position_embedding`
1065
- if seq_length > self.causal_mask.shape[-1]:
1066
- logger.info(f"Resizing causal mask buffer from {self.causal_mask.shape[-1]} to {2 * self.causal_mask.shape[-1]}")
1067
- causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
1068
- self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
1069
-
1070
- # We use the current dtype to avoid any overflows
1071
- min_dtype = torch.finfo(dtype).min
1072
- causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
1073
- causal_mask = causal_mask.to(dtype=dtype, device=device)
1074
-
1075
- if attention_mask is not None and attention_mask.dim() == 2:
1076
- mask_length = attention_mask.shape[-1]
1077
- padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
1078
- causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
1079
-
1080
- if self.config._attn_implementation == "sdpa" and attention_mask is not None:
1081
- # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
1082
- is_tracing = (
1083
- torch.jit.is_tracing()
1084
- or isinstance(input_tensor, torch.fx.Proxy)
1085
- or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
1086
- )
1087
-
1088
- if not is_tracing and torch.any(attention_mask != 1):
1089
- # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
1090
- # using left padding. This is required by
1091
- # F.scaled_dot_product_attention memory-efficient attention path.
1092
- # Details: https://github.com/pytorch/pytorch/issues/110213
1093
- causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to(dtype)
1094
-
1095
- return causal_mask
1096
-
1097
- class GemmoeForCausalLM(GemmoePreTrainedModel):
1098
- r"""
1099
- The Gemmoe Model transformer with a language modeling head on top for causal language modeling (CLM).
1100
 
1101
  Args:
1102
- config (GemmoeConfig): The configuration object for the Gemmoe model.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1103
 
1104
- Example usage:
1105
- ```python
1106
- >>> from transformers import AutoTokenizer, GemmoeForCausalLM
1107
 
1108
- >>> model = GemmoeForCausalLM.from_pretrained("google/gemmoe-7b")
1109
- >>> tokenizer = AutoTokenizer.from_pretrained("google/gemmoe-7b")
1110
 
1111
- >>> prompt = "What is your favorite condiment?"
1112
- >>> inputs = tokenizer(prompt, return_tensors="pt")
1113
 
1114
- >>> # Generate
1115
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1116
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1117
- "What is your favorite condiment?"
1118
- ```
1119
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1120
  _tied_weights_keys = ["lm_head.weight"]
1121
 
1122
  def __init__(self, config):
@@ -1124,9 +1310,6 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1124
  self.model = GemmoeModel(config)
1125
  self.vocab_size = config.vocab_size
1126
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1127
- self.router_aux_loss_coef = config.router_aux_loss_coef
1128
- self.num_experts = 8
1129
- self.num_experts_per_tok = config.num_experts_per_tok
1130
 
1131
  # Initialize weights and apply final processing
1132
  self.post_init()
@@ -1149,8 +1332,8 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1149
  def get_decoder(self):
1150
  return self.model
1151
 
1152
- @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
1153
- @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1154
  def forward(
1155
  self,
1156
  input_ids: torch.LongTensor = None,
@@ -1162,16 +1345,14 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1162
  use_cache: Optional[bool] = None,
1163
  output_attentions: Optional[bool] = None,
1164
  output_hidden_states: Optional[bool] = None,
1165
- output_router_logits: Optional[bool] = None,
1166
  return_dict: Optional[bool] = None,
1167
- cache_position: Optional[torch.LongTensor] = None,
1168
- ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
1169
  r"""
1170
  Args:
1171
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1172
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1173
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1174
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1175
 
1176
  Returns:
1177
 
@@ -1180,26 +1361,24 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1180
  ```python
1181
  >>> from transformers import AutoTokenizer, GemmoeForCausalLM
1182
 
1183
- >>> model = GemmoeForCausalLM.from_pretrained("google/gemmoe-7b")
1184
- >>> tokenizer = AutoTokenizer.from_pretrained("google/gemmoe-7b")
1185
 
1186
- >>> prompt = "What is your favorite condiment?"
1187
  >>> inputs = tokenizer(prompt, return_tensors="pt")
1188
 
1189
  >>> # Generate
1190
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1191
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1192
- "What is your favorite condiment?"
1193
  ```"""
1194
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1195
- output_router_logits = (
1196
- output_router_logits if output_router_logits is not None else getattr(self.config, "output_router_logits", False)
1197
- )
1198
  output_hidden_states = (
1199
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1200
  )
1201
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1202
 
 
1203
  outputs = self.model(
1204
  input_ids=input_ids,
1205
  attention_mask=attention_mask,
@@ -1209,61 +1388,46 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1209
  use_cache=use_cache,
1210
  output_attentions=output_attentions,
1211
  output_hidden_states=output_hidden_states,
1212
- output_router_logits=output_router_logits,
1213
  return_dict=return_dict,
1214
- cache_position=cache_position,
1215
  )
1216
 
1217
  hidden_states = outputs[0]
1218
-
1219
- # Ensure hidden_states and lm_head have compatible dtypes
1220
- hidden_states = hidden_states.to(dtype=self.lm_head.weight.dtype)
1221
-
1222
- logits = self.lm_head(hidden_states)
 
 
1223
 
1224
  loss = None
1225
  if labels is not None:
 
1226
  shift_logits = logits[..., :-1, :].contiguous()
1227
  shift_labels = labels[..., 1:].contiguous()
 
1228
  loss_fct = CrossEntropyLoss()
1229
  shift_logits = shift_logits.view(-1, self.config.vocab_size)
1230
  shift_labels = shift_labels.view(-1)
 
1231
  shift_labels = shift_labels.to(shift_logits.device)
1232
  loss = loss_fct(shift_logits, shift_labels)
1233
 
1234
- aux_loss = None
1235
- if output_router_logits:
1236
- router_logits = outputs.router_logits if return_dict else outputs[-1]
1237
- if router_logits is not None:
1238
- aux_loss = load_balancing_loss_func(
1239
- router_logits,
1240
- self.num_experts,
1241
- self.num_experts_per_tok,
1242
- attention_mask,
1243
- )
1244
- if labels is not None:
1245
- loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
1246
-
1247
  if not return_dict:
1248
  output = (logits,) + outputs[1:]
1249
- if aux_loss is not None:
1250
- output = (aux_loss,) + output
1251
  return (loss,) + output if loss is not None else output
1252
 
1253
- return MoeCausalLMOutputWithPast(
1254
  loss=loss,
1255
- aux_loss=aux_loss,
1256
  logits=logits,
1257
  past_key_values=outputs.past_key_values,
1258
  hidden_states=outputs.hidden_states,
1259
  attentions=outputs.attentions,
1260
- router_logits=outputs.router_logits,
1261
  )
1262
 
1263
  def prepare_inputs_for_generation(
1264
  self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1265
  ):
1266
- past_length = 0
1267
  if past_key_values is not None:
1268
  if isinstance(past_key_values, Cache):
1269
  cache_length = past_key_values.get_seq_length()
@@ -1273,11 +1437,19 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1273
  cache_length = past_length = past_key_values[0][0].shape[2]
1274
  max_cache_length = None
1275
 
 
 
 
 
1276
  if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1277
  input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
 
 
1278
  elif past_length < input_ids.shape[1]:
1279
  input_ids = input_ids[:, past_length:]
1280
-
 
 
1281
  if (
1282
  max_cache_length is not None
1283
  and attention_mask is not None
@@ -1287,37 +1459,26 @@ class GemmoeForCausalLM(GemmoePreTrainedModel):
1287
 
1288
  position_ids = kwargs.get("position_ids", None)
1289
  if attention_mask is not None and position_ids is None:
 
1290
  position_ids = attention_mask.long().cumsum(-1) - 1
1291
  position_ids.masked_fill_(attention_mask == 0, 1)
1292
  if past_key_values:
1293
  position_ids = position_ids[:, -input_ids.shape[1] :]
1294
 
1295
- if self.generation_config.cache_implementation == "static":
1296
- cache_position = kwargs.get("cache_position", None)
1297
- if cache_position is None:
1298
- past_length = 0
1299
- else:
1300
- past_length = cache_position[-1] + 1
1301
- input_ids = input_ids[:, -1].unsqueeze(-1)
1302
- position_ids = position_ids[:, -1].unsqueeze(-1)
1303
-
1304
- cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
1305
-
1306
  if inputs_embeds is not None and past_key_values is None:
1307
  model_inputs = {"inputs_embeds": inputs_embeds}
1308
  else:
1309
- model_inputs = {"input_ids": input_ids.contiguous()}
1310
 
1311
  model_inputs.update(
1312
  {
1313
- "position_ids": position_ids.contiguous(),
1314
- "cache_position": cache_position,
1315
  "past_key_values": past_key_values,
1316
  "use_cache": kwargs.get("use_cache"),
1317
  "attention_mask": attention_mask,
1318
  }
1319
  )
1320
-
1321
  return model_inputs
1322
 
1323
  @staticmethod
@@ -1350,6 +1511,7 @@ class GemmoeForSequenceClassification(GemmoePreTrainedModel):
1350
  self.num_labels = config.num_labels
1351
  self.model = GemmoeModel(config)
1352
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
 
1353
  # Initialize weights and apply final processing
1354
  self.post_init()
1355
 
@@ -1359,8 +1521,7 @@ class GemmoeForSequenceClassification(GemmoePreTrainedModel):
1359
  def set_input_embeddings(self, value):
1360
  self.model.embed_tokens = value
1361
 
1362
- @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
1363
- @replace_return_docstrings(output_type=SequenceClassifierOutputWithPast, config_class=_CONFIG_FOR_DOC)
1364
  def forward(
1365
  self,
1366
  input_ids: torch.LongTensor = None,
@@ -1374,25 +1535,14 @@ class GemmoeForSequenceClassification(GemmoePreTrainedModel):
1374
  output_hidden_states: Optional[bool] = None,
1375
  return_dict: Optional[bool] = None,
1376
  ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1377
- """
1378
- Forward pass of the sequence classification model.
1379
-
1380
- Args:
1381
- input_ids (torch.LongTensor, optional): Input token IDs.
1382
- attention_mask (torch.Tensor, optional): Attention mask.
1383
- position_ids (torch.LongTensor, optional): Position IDs.
1384
- past_key_values (List[torch.FloatTensor], optional): Past key-value pairs.
1385
- inputs_embeds (torch.FloatTensor, optional): Input embeddings.
1386
- labels (torch.LongTensor, optional): Labels for sequence classification.
1387
- use_cache (bool, optional): Whether to use cache.
1388
- output_attentions (bool, optional): Whether to output attentions.
1389
- output_hidden_states (bool, optional): Whether to output hidden states.
1390
- return_dict (bool, optional): Whether to return a dictionary or tuple.
1391
-
1392
- Returns:
1393
- Union[Tuple, SequenceClassifierOutputWithPast]: Output of the sequence classification model.
1394
  """
1395
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
1396
  transformer_outputs = self.model(
1397
  input_ids,
1398
  attention_mask=attention_mask,
@@ -1418,8 +1568,9 @@ class GemmoeForSequenceClassification(GemmoePreTrainedModel):
1418
  sequence_lengths = -1
1419
  else:
1420
  if input_ids is not None:
1421
- sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
1422
- sequence_lengths = sequence_lengths.clamp(min=0).to(logits.device)
 
1423
  else:
1424
  sequence_lengths = -1
1425
 
@@ -1448,7 +1599,6 @@ class GemmoeForSequenceClassification(GemmoePreTrainedModel):
1448
  elif self.config.problem_type == "multi_label_classification":
1449
  loss_fct = BCEWithLogitsLoss()
1450
  loss = loss_fct(pooled_logits, labels)
1451
-
1452
  if not return_dict:
1453
  output = (pooled_logits,) + transformer_outputs[1:]
1454
  return ((loss,) + output) if loss is not None else output
 
26
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
 
28
  from transformers.activations import ACT2FN
29
+ from transformers.cache_utils import Cache, DynamicCache
30
  from transformers.modeling_attn_mask_utils import (
31
+ AttentionMaskConverter,
32
+ _prepare_4d_attention_mask,
33
  _prepare_4d_causal_attention_mask,
34
+ _prepare_4d_causal_attention_mask_for_sdpa,
35
  )
36
+ from transformers.modeling_outputs import SequenceClassifierOutputWithPast, BaseModelOutputWithPast, CausalLMOutputWithPast
37
  from transformers.modeling_utils import PreTrainedModel
38
  from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
39
  from transformers.utils import (
 
63
 
64
  _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
65
 
 
66
  logger = logging.get_logger(__name__)
67
 
68
  _CONFIG_FOR_DOC = "GemmoeConfig"
 
158
  max_seqlen_in_batch,
159
  )
160
 
161
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
162
+ warnings.warn(
163
+ "Calling `transformers.models.Gemmoe.modeling_Gemmoe._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask"
164
+ )
165
+ return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
166
+
167
+ def _make_causal_mask(
168
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
169
+ ):
170
+ warnings.warn(
171
+ "Calling `transformers.models.Gemmoe.modeling_Gemmoe._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.Gemmoe.modeling_Gemmoe.AttentionMaskConverter._make_causal_mask"
172
+ )
173
+ return AttentionMaskConverter._make_causal_mask(
174
+ input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length
175
+ )
176
+
177
 
178
 
179
  class GemmoeRMSNorm(nn.Module):
180
+ def __init__(self, hidden_size, eps=1e-6):
181
+ """
182
+ GemmoeRMSNorm is equivalent to T5LayerNorm
183
+ """
184
  super().__init__()
185
+ self.weight = nn.Parameter(torch.ones(hidden_size))
186
+ self.variance_epsilon = eps
 
 
 
187
 
188
+ def forward(self, hidden_states):
189
+ input_dtype = hidden_states.dtype
190
+ hidden_states = hidden_states.to(torch.float32)
191
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
192
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
193
+ return self.weight * hidden_states.to(input_dtype)
194
 
195
  ALL_LAYERNORM_LAYERS.append(GemmoeRMSNorm)
196
 
197
  class GemmoeRotaryEmbedding(nn.Module):
198
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
199
  super().__init__()
200
+
201
  self.dim = dim
202
  self.max_position_embeddings = max_position_embeddings
203
  self.base = base
204
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
205
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
206
+
207
+ # Build here to make `torch.jit.trace` work.
208
+ self._set_cos_sin_cache(
209
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
210
+ )
211
+ self.max_seq_len_cached = None
212
+
213
 
214
  def _set_cos_sin_cache(self, seq_len, device, dtype):
215
  self.max_seq_len_cached = seq_len
216
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
217
+
218
+ freqs = torch.outer(t, self.inv_freq.to(t.device))
219
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
220
+ emb = torch.cat((freqs, freqs), dim=-1)
221
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
222
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
223
+
224
+ def forward(self, x, seq_len=None):
225
+ # x: [bs, num_attention_heads, seq_len, head_size]
226
+ if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
 
 
 
 
 
 
227
  self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
228
+
229
  return (
230
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
231
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
232
  )
233
+
234
+ class GemmoeLinearScalingRotaryEmbedding(GemmoeRotaryEmbedding):
235
+ """GemmoeRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
236
+
237
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
238
+ self.scaling_factor = scaling_factor
239
+ super().__init__(dim, max_position_embeddings, base, device)
240
+
241
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
242
+ self.max_seq_len_cached = seq_len
243
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
244
+ t = t / self.scaling_factor
245
+
246
+ freqs = torch.outer(t, self.inv_freq)
247
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
248
+ emb = torch.cat((freqs, freqs), dim=-1)
249
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
250
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
251
+
252
+ class GemmoeDynamicNTKScalingRotaryEmbedding(GemmoeRotaryEmbedding):
253
+ """GemmoeRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
254
+
255
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
256
+ self.scaling_factor = scaling_factor
257
+ super().__init__(dim, max_position_embeddings, base, device)
258
+
259
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
260
+ self.max_seq_len_cached = seq_len
261
+
262
+ if seq_len > self.max_position_embeddings:
263
+ base = self.base * (
264
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
265
+ ) ** (self.dim / (self.dim - 2))
266
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
267
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
268
+
269
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
270
+
271
+ freqs = torch.outer(t, self.inv_freq)
272
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
273
+ emb = torch.cat((freqs, freqs), dim=-1)
274
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
275
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
276
 
277
  def rotate_half(x):
278
  """Rotates half the hidden dims of the input."""
 
280
  x2 = x[..., x.shape[-1] // 2 :]
281
  return torch.cat((-x2, x1), dim=-1)
282
 
283
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
284
+ """Applies Rotary Position Embedding to the query and key tensors.
285
+
286
+ Args:
287
+ q (`torch.Tensor`): The query tensor.
288
+ k (`torch.Tensor`): The key tensor.
289
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
290
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
291
+ position_ids (`torch.Tensor`):
292
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
293
+ used to pass offsetted position ids when working with a KV-cache.
294
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
295
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
296
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
297
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
298
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
299
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
300
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
301
+ Returns:
302
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
303
+ """
304
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
305
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
306
  q_embed = (q * cos) + (rotate_half(q) * sin)
307
  k_embed = (k * cos) + (rotate_half(k) * sin)
308
  return q_embed, k_embed
309
 
310
+ class GemmoeMLP(nn.Module):
311
+ def __init__(self, config, hidden_size = None, intermediate_size = None):
312
+ super().__init__()
313
+ self.config = config
314
+ self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
315
+ self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
316
+
317
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
318
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
319
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
320
+ self.act_fn = ACT2FN[config.hidden_act]
321
+
322
+ def forward(self, x):
323
+ if self.config.pretraining_tp > 1:
324
+ slice = self.intermediate_size // self.config.pretraining_tp
325
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
326
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
327
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
328
+
329
+ gate_proj = torch.cat(
330
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
331
+ )
332
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
333
+
334
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
335
+ down_proj = [
336
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
337
+ ]
338
+ down_proj = sum(down_proj)
339
+ else:
340
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
341
+
342
+ return down_proj
343
+
344
+ class MoEGate(nn.Module):
345
+ def __init__(self, config):
346
+ super().__init__()
347
+ self.config = config
348
+ self.top_k = config.num_experts_per_tok
349
+ self.n_routed_experts = config.n_routed_experts
350
+
351
+ self.scoring_func = config.scoring_func
352
+ self.alpha = config.aux_loss_alpha
353
+ self.seq_aux = config.seq_aux
354
+
355
+ # topk selection algorithm
356
+ self.norm_topk_prob = config.norm_topk_prob
357
+ self.gating_dim = config.hidden_size
358
+ self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
359
+ self.reset_parameters()
360
+
361
+ def reset_parameters(self) -> None:
362
+ import torch.nn.init as init
363
+ init.kaiming_uniform_(self.weight, a=math.sqrt(5))
364
+
365
+ def forward(self, hidden_states):
366
+ bsz, seq_len, h = hidden_states.shape
367
+ ### compute gating score
368
+ hidden_states = hidden_states.view(-1, h)
369
+ logits = F.linear(hidden_states, self.weight, None)
370
+ if self.scoring_func == 'softmax':
371
+ scores = logits.softmax(dim=-1)
372
+ else:
373
+ raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
374
+
375
+ ### select top-k experts
376
+ topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
377
+
378
+ ### norm gate to sum 1
379
+ if self.top_k > 1 and self.norm_topk_prob:
380
+ denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
381
+ topk_weight = topk_weight / denominator
382
+
383
+ ### expert-level computation auxiliary loss
384
+ if self.training and self.alpha > 0.0:
385
+ scores_for_aux = scores
386
+ aux_topk = self.top_k
387
+ # always compute aux loss based on the naive greedy topk method
388
+ topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
389
+ if self.seq_aux:
390
+ scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
391
+ ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
392
+ ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(seq_len * aux_topk / self.n_routed_experts)
393
+ aux_loss = (ce * scores_for_seq_aux.mean(dim = 1)).sum(dim = 1).mean() * self.alpha
394
+ else:
395
+ mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
396
+ ce = mask_ce.float().mean(0)
397
+ Pi = scores_for_aux.mean(0)
398
+ fi = ce * self.n_routed_experts
399
+ aux_loss = (Pi * fi).sum() * self.alpha
400
+ else:
401
+ aux_loss = None
402
+ return topk_idx, topk_weight, aux_loss
403
+
404
+ class AddAuxiliaryLoss(torch.autograd.Function):
405
+ """
406
+ The trick function of adding auxiliary (aux) loss,
407
+ which includes the gradient of the aux loss during backpropagation.
408
+ """
409
+ @staticmethod
410
+ def forward(ctx, x, loss):
411
+ assert loss.numel() == 1
412
+ ctx.dtype = loss.dtype
413
+ ctx.required_aux_loss = loss.requires_grad
414
+ return x
415
+
416
+ @staticmethod
417
+ def backward(ctx, grad_output):
418
+ grad_loss = None
419
+ if ctx.required_aux_loss:
420
+ grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
421
+ return grad_output, grad_loss
422
+
423
+ class GemMoE(nn.Module):
424
+ """
425
+ A mixed expert module containing shared experts.
426
+ """
427
+ def __init__(self, config):
428
+ super().__init__()
429
+ self.config = config
430
+ self.num_experts_per_tok = config.num_experts_per_tok
431
+ self.experts = nn.ModuleList([GemmoeMLP(config, intermediate_size = config.moe_intermediate_size) for i in range(config.n_routed_experts)])
432
+ self.gate = MoEGate(config)
433
+ if config.n_shared_experts is not None:
434
+ intermediate_size = config.moe_intermediate_size * config.n_shared_experts
435
+ self.shared_experts = GemmoeMLP(config=config, intermediate_size = intermediate_size)
436
+
437
+ def forward(self, hidden_states):
438
+ identity = hidden_states
439
+ orig_shape = hidden_states.shape
440
+ topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
441
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
442
+ flat_topk_idx = topk_idx.view(-1)
443
+ if self.training:
444
+ hidden_states = hidden_states.repeat_interleave(self.num_experts_per_tok, dim=0)
445
+ y = torch.empty_like(hidden_states)
446
+ for i, expert in enumerate(self.experts):
447
+ y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
448
+ y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
449
+ y = y.view(*orig_shape)
450
+ y = AddAuxiliaryLoss.apply(y, aux_loss)
451
+ else:
452
+ y = self.moe_infer(hidden_states, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
453
+ if self.config.n_shared_experts is not None:
454
+ y = y + self.shared_experts(identity)
455
+ return y
456
+
457
+ @torch.no_grad()
458
+ def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
459
+ expert_cache = torch.zeros_like(x)
460
+ idxs = flat_expert_indices.argsort()
461
+ tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
462
+ token_idxs = idxs // self.num_experts_per_tok
463
+ for i, end_idx in enumerate(tokens_per_expert):
464
+ start_idx = 0 if i == 0 else tokens_per_expert[i-1]
465
+ if start_idx == end_idx:
466
+ continue
467
+ expert = self.experts[i]
468
+ exp_token_idx = token_idxs[start_idx:end_idx]
469
+ expert_tokens = x[exp_token_idx]
470
+ expert_out = expert(expert_tokens)
471
+ expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
472
+ expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum')
473
+ return expert_cache
474
+
475
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
476
  """
477
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
478
  num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
 
482
  return hidden_states
483
  hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
484
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
485
+
486
 
487
  class GemmoeAttention(nn.Module):
488
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
 
 
 
 
 
 
489
 
490
  def __init__(self, config: GemmoeConfig, layer_idx: Optional[int] = None):
491
  super().__init__()
 
493
  self.layer_idx = layer_idx
494
  if layer_idx is None:
495
  logger.warning_once(
496
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
497
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
498
  "when creating this class."
499
  )
500
+
501
  self.attention_dropout = config.attention_dropout
502
  self.hidden_size = config.hidden_size
503
  self.num_heads = config.num_attention_heads
504
+ self.head_dim = self.hidden_size // self.num_heads
505
  self.num_key_value_heads = config.num_key_value_heads
506
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
507
  self.max_position_embeddings = config.max_position_embeddings
508
  self.rope_theta = config.rope_theta
509
  self.is_causal = True
510
 
511
+ if (self.head_dim * self.num_heads) != self.hidden_size:
512
  raise ValueError(
513
  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
514
  f" and `num_heads`: {self.num_heads})."
515
  )
516
+
517
  self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
518
  self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
519
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
520
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
521
+ self._init_rope()
522
+
523
+ def _init_rope(self):
524
+ if self.config.rope_scaling is None:
525
+ self.rotary_emb = GemmoeRotaryEmbedding(
526
+ self.head_dim,
527
+ max_position_embeddings=self.max_position_embeddings,
528
+ base=self.rope_theta,
529
+ )
530
+ else:
531
+ scaling_type = self.config.rope_scaling["type"]
532
+ scaling_factor = self.config.rope_scaling["factor"]
533
+ if scaling_type == "linear":
534
+ self.rotary_emb = GemmoeLinearScalingRotaryEmbedding(
535
+ self.head_dim,
536
+ max_position_embeddings=self.max_position_embeddings,
537
+ scaling_factor=scaling_factor,
538
+ base=self.rope_theta,
539
+ )
540
+ elif scaling_type == "dynamic":
541
+ self.rotary_emb = GemmoeDynamicNTKScalingRotaryEmbedding(
542
+ self.head_dim,
543
+ max_position_embeddings=self.max_position_embeddings,
544
+ scaling_factor=scaling_factor,
545
+ base=self.rope_theta,
546
+ )
547
+ else:
548
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
549
+
550
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
551
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
552
 
553
  def forward(
554
  self,
 
558
  past_key_value: Optional[Cache] = None,
559
  output_attentions: bool = False,
560
  use_cache: bool = False,
 
561
  **kwargs,
562
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
563
+ if "padding_mask" in kwargs:
564
+ warnings.warn(
565
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
566
+ )
 
 
 
 
 
 
 
 
567
 
 
 
 
 
 
 
568
  bsz, q_len, _ = hidden_states.size()
569
 
570
+ if self.config.pretraining_tp > 1:
571
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
572
+ query_slices = self.q_proj.weight.split(
573
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
574
+ )
575
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
576
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
577
+
578
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
579
+ query_states = torch.cat(query_states, dim=-1)
580
+
581
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
582
+ key_states = torch.cat(key_states, dim=-1)
583
+
584
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
585
+ value_states = torch.cat(value_states, dim=-1)
586
+
587
+ else:
588
+ query_states = self.q_proj(hidden_states)
589
+ key_states = self.k_proj(hidden_states)
590
+ value_states = self.v_proj(hidden_states)
591
 
592
  query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
593
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
594
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
595
 
596
+ kv_seq_len = key_states.shape[-2]
597
+ if past_key_value is not None:
598
+ if self.layer_idx is None:
599
+ raise ValueError(
600
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
601
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
602
+ "with a layer index."
603
+ )
604
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
605
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
606
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
607
 
608
  if past_key_value is not None:
609
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
 
610
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
611
 
612
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
613
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
614
 
615
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
616
 
617
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
618
+ raise ValueError(
619
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
620
+ f" {attn_weights.size()}"
621
+ )
622
+
623
+ if attention_mask is not None:
624
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
625
+ raise ValueError(
626
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
627
+ )
628
+ attn_weights = attn_weights + attention_mask
629
 
630
  # upcast attention to fp32
631
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
632
  attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
 
633
  attn_output = torch.matmul(attn_weights, value_states)
634
 
635
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
 
639
  )
640
 
641
  attn_output = attn_output.transpose(1, 2).contiguous()
 
642
 
643
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
644
+
645
+ if self.config.pretraining_tp > 1:
646
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
647
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
648
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
649
+ else:
650
+ attn_output = self.o_proj(attn_output)
651
 
652
  if not output_attentions:
653
  attn_weights = None
 
660
  untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
661
  flash attention and deal with padding tokens in case the input contains any of them.
662
  """
663
+
664
  def __init__(self, *args, **kwargs):
665
  super().__init__(*args, **kwargs)
666
+
667
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
668
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
669
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
670
  self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
671
 
672
  def forward(
 
677
  past_key_value: Optional[Cache] = None,
678
  output_attentions: bool = False,
679
  use_cache: bool = False,
 
680
  **kwargs,
681
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
682
+ # GemmoeFlashAttention2 attention does not support output_attentions
683
+ if "padding_mask" in kwargs:
684
+ warnings.warn(
685
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
686
+ )
687
+
688
+ # overwrite attention_mask with padding_mask
689
+ attention_mask = kwargs.pop("padding_mask")
690
+
691
  output_attentions = False
692
 
693
  bsz, q_len, _ = hidden_states.size()
 
703
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
704
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
705
 
706
+ kv_seq_len = key_states.shape[-2]
707
+ if past_key_value is not None:
708
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
709
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
710
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
711
 
 
712
  if past_key_value is not None:
713
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
 
714
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
715
 
716
  # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
 
726
  # cast them back in the correct dtype just to be sure everything works as expected.
727
  # This might slowdown training & inference so it is recommended to not cast the LayerNorms
728
  # in fp32. (GemmoeRMSNorm handles it correctly)
729
+
730
  input_dtype = query_states.dtype
731
  if input_dtype == torch.float32:
 
 
732
  # Handle the case where the model is quantized
733
+ if hasattr(self.config, "_pre_quantization_dtype"):
734
  target_dtype = self.config._pre_quantization_dtype
735
+ elif torch.is_autocast_enabled():
736
+ target_dtype = torch.get_autocast_gpu_dtype()
737
  else:
738
  target_dtype = self.q_proj.weight.dtype
739
 
 
742
  f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
743
  f" {target_dtype}."
744
  )
745
+
746
  query_states = query_states.to(target_dtype)
747
  key_states = key_states.to(target_dtype)
748
  value_states = value_states.to(target_dtype)
 
751
  query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
752
  )
753
 
754
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
755
  attn_output = self.o_proj(attn_output)
756
 
757
  if not output_attentions:
 
793
  query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
794
  query_states, key_states, value_states, attention_mask, query_length
795
  )
796
+
797
  cu_seqlens_q, cu_seqlens_k = cu_seq_lens
798
  max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
799
 
 
809
  softmax_scale=softmax_scale,
810
  causal=causal,
811
  )
812
+
813
  attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
814
  else:
815
  attn_output = flash_attn_func(
 
820
 
821
  def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
822
  indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
 
823
  batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
824
+
825
  key_layer = index_first_axis(
826
  key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
827
  )
828
  value_layer = index_first_axis(
829
  value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
830
  )
 
831
  if query_length == kv_seq_len:
832
  query_layer = index_first_axis(
833
  query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
 
859
  class GemmoeSdpaAttention(GemmoeAttention):
860
  """
861
  Gemmoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
862
+ `GemmoeAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
863
  SDPA API.
864
  """
865
 
866
+ # Adapted from GemmoeAttention.forward
 
 
 
 
 
 
 
 
 
 
867
  def forward(
868
  self,
869
  hidden_states: torch.Tensor,
 
872
  past_key_value: Optional[Cache] = None,
873
  output_attentions: bool = False,
874
  use_cache: bool = False,
 
875
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
876
  if output_attentions:
877
  # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
878
+ logger.warning_once(
879
+ "GemmoeModel is using GemmoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
880
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
881
+ )
 
882
  return super().forward(
883
  hidden_states=hidden_states,
884
  attention_mask=attention_mask,
 
886
  past_key_value=past_key_value,
887
  output_attentions=output_attentions,
888
  use_cache=use_cache,
 
889
  )
890
+
891
  bsz, q_len, _ = hidden_states.size()
892
 
893
  query_states = self.q_proj(hidden_states)
 
898
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
899
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
900
 
901
+ kv_seq_len = key_states.shape[-2]
902
+ if past_key_value is not None:
903
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
904
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
905
+
906
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
907
 
 
908
  if past_key_value is not None:
909
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
 
910
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
911
 
912
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
913
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
914
 
915
+ if attention_mask is not None:
916
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
917
+ raise ValueError(
918
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
919
+ )
 
 
 
920
 
921
  # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
922
  # Reference: https://github.com/pytorch/pytorch/issues/112577.
923
+ if query_states.device.type == "cuda" and attention_mask is not None:
924
  query_states = query_states.contiguous()
925
  key_states = key_states.contiguous()
926
  value_states = value_states.contiguous()
927
 
 
 
 
 
928
  attn_output = torch.nn.functional.scaled_dot_product_attention(
929
  query_states,
930
  key_states,
931
  value_states,
932
+ attn_mask=attention_mask,
933
  dropout_p=self.attention_dropout if self.training else 0.0,
934
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
935
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
936
  )
937
 
938
  attn_output = attn_output.transpose(1, 2).contiguous()
939
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
940
+
941
  attn_output = self.o_proj(attn_output)
942
 
943
  return attn_output, None, past_key_value
 
948
  "sdpa": GemmoeSdpaAttention,
949
  }
950
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
951
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
952
  class GemmoeDecoderLayer(nn.Module):
953
  def __init__(self, config: GemmoeConfig, layer_idx: int):
954
  super().__init__()
955
  self.hidden_size = config.hidden_size
956
 
957
+ self.self_attn = GEMMOE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
958
 
959
+ self.mlp = GemMoE(config) if (config.n_routed_experts is not None and \
960
+ layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0) \
961
+ else GemmoeMLP(config)
962
  self.input_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
963
  self.post_attention_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
964
 
 
969
  position_ids: Optional[torch.LongTensor] = None,
970
  past_key_value: Optional[Tuple[torch.Tensor]] = None,
971
  output_attentions: Optional[bool] = False,
 
972
  use_cache: Optional[bool] = False,
 
973
  **kwargs,
974
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
975
  """
 
985
  If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
986
  (see `past_key_values`).
987
  past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
 
 
 
988
  """
989
  if "padding_mask" in kwargs:
990
  warnings.warn(
991
  "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
992
  )
 
993
  residual = hidden_states
994
+
995
  hidden_states = self.input_layernorm(hidden_states)
996
 
997
  # Self Attention
 
1002
  past_key_value=past_key_value,
1003
  output_attentions=output_attentions,
1004
  use_cache=use_cache,
 
1005
  **kwargs,
1006
  )
1007
  hidden_states = residual + hidden_states
 
1009
  # Fully Connected
1010
  residual = hidden_states
1011
  hidden_states = self.post_attention_layernorm(hidden_states)
1012
+ hidden_states = self.mlp(hidden_states)
1013
  hidden_states = residual + hidden_states
 
1014
 
1015
  outputs = (hidden_states,)
1016
 
 
1020
  if use_cache:
1021
  outputs += (present_key_value,)
1022
 
 
 
 
1023
  return outputs
1024
 
1025
+
1026
  GEMMOE_START_DOCSTRING = r"""
1027
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1028
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1029
+ etc.)
1030
+
1031
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1032
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1033
+ and behavior.
1034
+
1035
+ Parameters:
1036
+ config ([`GemmoeConfig`]):
1037
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
1038
+ load the weights associated with the model, only the configuration. Check out the
1039
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1040
  """
1041
 
1042
  @add_start_docstrings(
 
1045
  )
1046
 
1047
  class GemmoePreTrainedModel(PreTrainedModel):
1048
+ config_class = GemmoeConfig
1049
+ base_model_prefix = "model"
1050
+ supports_gradient_checkpointing = True
1051
+ _no_split_modules = ["GemmoeDecoderLayer"]
1052
+ _skip_keys_device_placement = "past_key_values"
1053
+ _supports_flash_attn_2 = True
1054
+ _supports_sdpa = True
1055
+ _supports_cache_class = True
1056
+
1057
+ def _init_weights(self, module):
1058
+ std = self.config.initializer_range
1059
+ if isinstance(module, nn.Linear):
1060
+ module.weight.data.normal_(mean=0.0, std=std)
1061
+ if module.bias is not None:
1062
+ module.bias.data.zero_()
1063
+ elif isinstance(module, nn.Embedding):
1064
+ module.weight.data.normal_(mean=0.0, std=std)
1065
+ if module.padding_idx is not None:
1066
+ module.weight.data[module.padding_idx].zero_()
1067
+
1068
+
1069
+ Gemmoe_INPUTS_DOCSTRING = r"""
1070
+ Args:
1071
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1072
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1073
+ it.
1074
+
1075
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1076
+ [`PreTrainedTokenizer.__call__`] for details.
1077
+
1078
+ [What are input IDs?](../glossary#input-ids)
1079
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1080
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1081
+
1082
+ - 1 for tokens that are **not masked**,
1083
+ - 0 for tokens that are **masked**.
1084
+
1085
+ [What are attention masks?](../glossary#attention-mask)
1086
+
1087
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1088
+ [`PreTrainedTokenizer.__call__`] for details.
1089
+
1090
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
1091
+ `past_key_values`).
1092
+
1093
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1094
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1095
+ information on the default strategy.
1096
+
1097
+ - 1 indicates the head is **not masked**,
1098
+ - 0 indicates the head is **masked**.
1099
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1100
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1101
+ config.n_positions - 1]`.
1102
+
1103
+ [What are position IDs?](../glossary#position-ids)
1104
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
1105
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1106
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
1107
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
1108
+
1109
+ Two formats are allowed:
1110
+ - a [`~cache_utils.Cache`] instance;
1111
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1112
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1113
+ cache format.
1114
+
1115
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1116
+ legacy cache format will be returned.
1117
+
1118
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1119
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1120
+ of shape `(batch_size, sequence_length)`.
1121
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1122
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1123
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1124
+ model's internal embedding lookup matrix.
1125
+ use_cache (`bool`, *optional*):
1126
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1127
+ `past_key_values`).
1128
+ output_attentions (`bool`, *optional*):
1129
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1130
+ tensors for more detail.
1131
+ output_hidden_states (`bool`, *optional*):
1132
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1133
+ more detail.
1134
+ return_dict (`bool`, *optional*):
1135
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1136
  """
1137
 
1138
  @add_start_docstrings(
 
1141
  )
1142
 
1143
  class GemmoeModel(GemmoePreTrainedModel):
1144
+ """
1145
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmoeDecoderLayer`]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1146
 
1147
  Args:
1148
+ config: GemmoeConfig
1149
+ """
1150
+
1151
+ def __init__(self, config: GemmoeConfig):
1152
+ super().__init__(config)
1153
+ self.padding_idx = config.pad_token_id
1154
+ self.vocab_size = config.vocab_size
1155
+
1156
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1157
+ self.layers = nn.ModuleList(
1158
+ [GemmoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1159
+ )
1160
+ self._use_sdpa = config._attn_implementation == "sdpa"
1161
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1162
+ self.norm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1163
 
1164
+ self.gradient_checkpointing = False
1165
+ # Initialize weights and apply final processing
1166
+ self.post_init()
1167
 
1168
+ def get_input_embeddings(self):
1169
+ return self.embed_tokens
1170
 
1171
+ def set_input_embeddings(self, value):
1172
+ self.embed_tokens = value
1173
 
1174
+ @add_start_docstrings_to_model_forward(Gemmoe_INPUTS_DOCSTRING)
1175
+ def forward(
1176
+ self,
1177
+ input_ids: torch.LongTensor = None,
1178
+ attention_mask: Optional[torch.Tensor] = None,
1179
+ position_ids: Optional[torch.LongTensor] = None,
1180
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1181
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1182
+ use_cache: Optional[bool] = None,
1183
+ output_attentions: Optional[bool] = None,
1184
+ output_hidden_states: Optional[bool] = None,
1185
+ return_dict: Optional[bool] = None,
1186
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1187
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1188
+ output_hidden_states = (
1189
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1190
+ )
1191
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1192
+
1193
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1194
+
1195
+ # retrieve input_ids and inputs_embeds
1196
+ if input_ids is not None and inputs_embeds is not None:
1197
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1198
+ elif input_ids is not None:
1199
+ batch_size, seq_length = input_ids.shape[:2]
1200
+ elif inputs_embeds is not None:
1201
+ batch_size, seq_length = inputs_embeds.shape[:2]
1202
+ else:
1203
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1204
+
1205
+ if self.gradient_checkpointing and self.training:
1206
+ if use_cache:
1207
+ logger.warning_once(
1208
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers."
1209
+ )
1210
+ use_cache = False
1211
+
1212
+ past_key_values_length = 0
1213
+ if use_cache:
1214
+ use_legacy_cache = not isinstance(past_key_values, Cache)
1215
+ if use_legacy_cache:
1216
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1217
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
1218
+
1219
+ if position_ids is None:
1220
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1221
+ position_ids = torch.arange(
1222
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1223
+ )
1224
+ position_ids = position_ids.unsqueeze(0)
1225
+
1226
+ if inputs_embeds is None:
1227
+ inputs_embeds = self.embed_tokens(input_ids)
1228
+
1229
+ if self._use_flash_attention_2:
1230
+ # 2d mask is passed through the layers
1231
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1232
+ elif self._use_sdpa and not output_attentions:
1233
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
1234
+ # the manual implementation that requires a 4D causal mask in all cases.
1235
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1236
+ attention_mask,
1237
+ (batch_size, seq_length),
1238
+ inputs_embeds,
1239
+ past_key_values_length,
1240
+ )
1241
+ else:
1242
+ # 4d mask is passed through the layers
1243
+ attention_mask = _prepare_4d_causal_attention_mask(
1244
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
1245
+ )
1246
+
1247
+ # embed positions
1248
+ hidden_states = inputs_embeds
1249
+
1250
+ # decoder layers
1251
+ all_hidden_states = () if output_hidden_states else None
1252
+ all_self_attns = () if output_attentions else None
1253
+ next_decoder_cache = None
1254
+
1255
+ for decoder_layer in self.layers:
1256
+ if output_hidden_states:
1257
+ all_hidden_states += (hidden_states,)
1258
+
1259
+ if self.gradient_checkpointing and self.training:
1260
+ layer_outputs = self._gradient_checkpointing_func(
1261
+ decoder_layer.__call__,
1262
+ hidden_states,
1263
+ attention_mask,
1264
+ position_ids,
1265
+ past_key_values,
1266
+ output_attentions,
1267
+ use_cache,
1268
+ )
1269
+ else:
1270
+ layer_outputs = decoder_layer(
1271
+ hidden_states,
1272
+ attention_mask=attention_mask,
1273
+ position_ids=position_ids,
1274
+ past_key_value=past_key_values,
1275
+ output_attentions=output_attentions,
1276
+ use_cache=use_cache,
1277
+ )
1278
+
1279
+ hidden_states = layer_outputs[0]
1280
+
1281
+ if use_cache:
1282
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1283
+
1284
+ if output_attentions:
1285
+ all_self_attns += (layer_outputs[1],)
1286
+
1287
+ hidden_states = self.norm(hidden_states)
1288
+
1289
+ # add hidden states from the last decoder layer
1290
+ if output_hidden_states:
1291
+ all_hidden_states += (hidden_states,)
1292
+
1293
+ next_cache = None
1294
+ if use_cache:
1295
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1296
+ if not return_dict:
1297
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1298
+ return BaseModelOutputWithPast(
1299
+ last_hidden_state=hidden_states,
1300
+ past_key_values=next_cache,
1301
+ hidden_states=all_hidden_states,
1302
+ attentions=all_self_attns,
1303
+ )
1304
+
1305
+ class GemmoeForCausalLM(GemmoePreTrainedModel):
1306
  _tied_weights_keys = ["lm_head.weight"]
1307
 
1308
  def __init__(self, config):
 
1310
  self.model = GemmoeModel(config)
1311
  self.vocab_size = config.vocab_size
1312
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
 
 
1313
 
1314
  # Initialize weights and apply final processing
1315
  self.post_init()
 
1332
  def get_decoder(self):
1333
  return self.model
1334
 
1335
+ @add_start_docstrings_to_model_forward(Gemmoe_INPUTS_DOCSTRING)
1336
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1337
  def forward(
1338
  self,
1339
  input_ids: torch.LongTensor = None,
 
1345
  use_cache: Optional[bool] = None,
1346
  output_attentions: Optional[bool] = None,
1347
  output_hidden_states: Optional[bool] = None,
 
1348
  return_dict: Optional[bool] = None,
1349
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
 
1350
  r"""
1351
  Args:
1352
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1353
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers.,
1354
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1355
+ (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`.
1356
 
1357
  Returns:
1358
 
 
1361
  ```python
1362
  >>> from transformers import AutoTokenizer, GemmoeForCausalLM
1363
 
1364
+ >>> model = GemmoeForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1365
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1366
 
1367
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1368
  >>> inputs = tokenizer(prompt, return_tensors="pt")
1369
 
1370
  >>> # Generate
1371
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1372
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1373
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1374
  ```"""
1375
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
1376
  output_hidden_states = (
1377
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1378
  )
1379
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1380
 
1381
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1382
  outputs = self.model(
1383
  input_ids=input_ids,
1384
  attention_mask=attention_mask,
 
1388
  use_cache=use_cache,
1389
  output_attentions=output_attentions,
1390
  output_hidden_states=output_hidden_states,
 
1391
  return_dict=return_dict,
 
1392
  )
1393
 
1394
  hidden_states = outputs[0]
1395
+ if self.config.pretraining_tp > 1:
1396
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1397
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1398
+ logits = torch.cat(logits, dim=-1)
1399
+ else:
1400
+ logits = self.lm_head(hidden_states)
1401
+ logits = logits.float()
1402
 
1403
  loss = None
1404
  if labels is not None:
1405
+ # Shift so that tokens < n predict n
1406
  shift_logits = logits[..., :-1, :].contiguous()
1407
  shift_labels = labels[..., 1:].contiguous()
1408
+ # Flatten the tokens
1409
  loss_fct = CrossEntropyLoss()
1410
  shift_logits = shift_logits.view(-1, self.config.vocab_size)
1411
  shift_labels = shift_labels.view(-1)
1412
+ # Enable model parallelism
1413
  shift_labels = shift_labels.to(shift_logits.device)
1414
  loss = loss_fct(shift_logits, shift_labels)
1415
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1416
  if not return_dict:
1417
  output = (logits,) + outputs[1:]
 
 
1418
  return (loss,) + output if loss is not None else output
1419
 
1420
+ return CausalLMOutputWithPast(
1421
  loss=loss,
 
1422
  logits=logits,
1423
  past_key_values=outputs.past_key_values,
1424
  hidden_states=outputs.hidden_states,
1425
  attentions=outputs.attentions,
 
1426
  )
1427
 
1428
  def prepare_inputs_for_generation(
1429
  self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1430
  ):
 
1431
  if past_key_values is not None:
1432
  if isinstance(past_key_values, Cache):
1433
  cache_length = past_key_values.get_seq_length()
 
1437
  cache_length = past_length = past_key_values[0][0].shape[2]
1438
  max_cache_length = None
1439
 
1440
+ # Keep only the unprocessed tokens:
1441
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1442
+ # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
1443
+ # input)
1444
  if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1445
  input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1446
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1447
+ # input_ids based on the past_length.
1448
  elif past_length < input_ids.shape[1]:
1449
  input_ids = input_ids[:, past_length:]
1450
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1451
+
1452
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1453
  if (
1454
  max_cache_length is not None
1455
  and attention_mask is not None
 
1459
 
1460
  position_ids = kwargs.get("position_ids", None)
1461
  if attention_mask is not None and position_ids is None:
1462
+ # create position_ids on the fly for batch generation
1463
  position_ids = attention_mask.long().cumsum(-1) - 1
1464
  position_ids.masked_fill_(attention_mask == 0, 1)
1465
  if past_key_values:
1466
  position_ids = position_ids[:, -input_ids.shape[1] :]
1467
 
1468
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
 
 
 
 
 
 
 
 
 
 
1469
  if inputs_embeds is not None and past_key_values is None:
1470
  model_inputs = {"inputs_embeds": inputs_embeds}
1471
  else:
1472
+ model_inputs = {"input_ids": input_ids}
1473
 
1474
  model_inputs.update(
1475
  {
1476
+ "position_ids": position_ids,
 
1477
  "past_key_values": past_key_values,
1478
  "use_cache": kwargs.get("use_cache"),
1479
  "attention_mask": attention_mask,
1480
  }
1481
  )
 
1482
  return model_inputs
1483
 
1484
  @staticmethod
 
1511
  self.num_labels = config.num_labels
1512
  self.model = GemmoeModel(config)
1513
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1514
+
1515
  # Initialize weights and apply final processing
1516
  self.post_init()
1517
 
 
1521
  def set_input_embeddings(self, value):
1522
  self.model.embed_tokens = value
1523
 
1524
+ @add_start_docstrings_to_model_forward(Gemmoe_INPUTS_DOCSTRING)
 
1525
  def forward(
1526
  self,
1527
  input_ids: torch.LongTensor = None,
 
1535
  output_hidden_states: Optional[bool] = None,
1536
  return_dict: Optional[bool] = None,
1537
  ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1538
+ r"""
1539
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1540
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, transformers.,
1541
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1542
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
 
 
 
 
 
 
 
 
 
 
 
 
1543
  """
1544
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1545
+
1546
  transformer_outputs = self.model(
1547
  input_ids,
1548
  attention_mask=attention_mask,
 
1568
  sequence_lengths = -1
1569
  else:
1570
  if input_ids is not None:
1571
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
1572
+ logits.device
1573
+ )
1574
  else:
1575
  sequence_lengths = -1
1576
 
 
1599
  elif self.config.problem_type == "multi_label_classification":
1600
  loss_fct = BCEWithLogitsLoss()
1601
  loss = loss_fct(pooled_logits, labels)
 
1602
  if not return_dict:
1603
  output = (pooled_logits,) + transformer_outputs[1:]
1604
  return ((loss,) + output) if loss is not None else output