JingzeShi committed
Commit c3bc525 (verified)
1 parent: a6ad3a7

Upload DogeForCausalLM

Files changed (4):
  1. config.json +1 -1
  2. configuration_doge.py +1 -1
  3. generation_config.json +1 -1
  4. modeling_doge.py +248 -186
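Since the upload ships custom `configuration_doge.py` and `modeling_doge.py` modules alongside the checkpoint, the model is loaded through the remote-code path. A minimal usage sketch; the repository id below is a placeholder, not something stated in this commit:

```python
# Hypothetical usage sketch: replace the repo id with the actual model repository.
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "JingzeShi/Doge-20M"  # placeholder, not confirmed by this commit
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("Hello", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```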
config.json CHANGED
@@ -42,7 +42,7 @@
   "rope_theta": 10000.0,
   "tie_word_embeddings": true,
   "torch_dtype": "float32",
-  "transformers_version": "4.48.2",
+  "transformers_version": "4.48.3",
   "use_cache": true,
   "vocab_size": 32768
 }
configuration_doge.py CHANGED
@@ -144,7 +144,7 @@ class DogeConfig(PretrainedConfig):
         "layers.*.self_attn.q_proj": "colwise",
         "layers.*.self_attn.k_proj": "colwise",
         "layers.*.self_attn.v_proj": "colwise",
-        "layers.*.self_attn.dt_proj": "colwise",
+        "layers.*.self_attn.dt_proj": "rowwise",
         "layers.*.self_attn.o_proj": "rowwise",
         "layers.*.mlp.gate_proj": "colwise",
         "layers.*.mlp.up_proj": "colwise",
generation_config.json CHANGED
@@ -3,5 +3,5 @@
   "bos_token_id": 0,
   "eos_token_id": 1,
   "pad_token_id": 2,
-  "transformers_version": "4.48.2"
+  "transformers_version": "4.48.3"
 }
modeling_doge.py CHANGED
@@ -28,14 +28,12 @@ from typing import Callable, List, Optional, Tuple, Union
28
  import torch
29
  import torch.nn.functional as F
30
  from torch import nn
 
31
  from transformers.activations import ACT2FN
32
  from transformers.cache_utils import Cache, DynamicCache, StaticCache
33
  from transformers.generation import GenerationMixin
34
- from transformers.modeling_outputs import (
35
- BaseModelOutputWithPast,
36
- CausalLMOutputWithPast,
37
- SequenceClassifierOutputWithPast,
38
- )
39
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
40
  from transformers.modeling_utils import PreTrainedModel
41
  from transformers.processing_utils import Unpack
@@ -47,24 +45,20 @@ from transformers.utils import (
47
  logging,
48
  replace_return_docstrings,
49
  )
50
- from transformers.utils.deprecation import deprecate_kwarg
51
-
52
  from .configuration_doge import DogeConfig
53
 
54
-
55
  if is_torch_flex_attn_available():
56
  from torch.nn.attention.flex_attention import flex_attention
57
 
58
-
59
  logger = logging.get_logger(__name__)
60
 
61
  _CONFIG_FOR_DOC = "DogeConfig"
62
 
63
 
64
- class RMSNorm(nn.Module):
65
  def __init__(self, hidden_size, eps=1e-6):
66
  """
67
- RMSNorm is equivalent to T5LayerNorm
68
  """
69
  super().__init__()
70
  self.weight = nn.Parameter(torch.ones(hidden_size))
@@ -81,7 +75,7 @@ class RMSNorm(nn.Module):
81
  return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
82
 
83
 
84
- class Residual(nn.Module):
85
  def __init__(self, hidden_size):
86
  super().__init__()
87
  self.weight = nn.Parameter(torch.ones(hidden_size))
@@ -93,8 +87,8 @@ class Residual(nn.Module):
93
  return f"{tuple(self.weight.shape)}"
94
 
95
 
96
- class RotaryEmbedding(nn.Module):
97
- def __init__(self, config: Optional[DogeConfig] = None, device=None):
98
  super().__init__()
99
  # BC: "rope_type" was originally "type"
100
  if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
@@ -155,9 +149,7 @@ class RotaryEmbedding(nn.Module):
155
 
156
 
157
  def rotate_half(x):
158
- """
159
- Rotates half the hidden dims of the input.
160
- """
161
  x1 = x[..., : x.shape[-1] // 2]
162
  x2 = x[..., x.shape[-1] // 2 :]
163
  return torch.cat((-x2, x1), dim=-1)
@@ -175,10 +167,11 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
175
  Deprecated and unused.
176
  unsqueeze_dim (`int`, *optional*, defaults to 1):
177
  The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
178
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k.
179
- For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim].
180
- Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k.
181
- Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
 
182
  Returns:
183
  `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
184
  """
@@ -191,8 +184,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
191
 
192
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
193
  """
194
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
195
- The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
196
  """
197
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
198
  if n_rep == 1:
@@ -201,6 +194,148 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
201
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
202
 
203
 
204
  class DogeDynamicMaskAttention(nn.Module):
205
  """Dynamic Mask Attention from 'Wonderful Matrices' paper."""
206
 
@@ -208,19 +343,12 @@ class DogeDynamicMaskAttention(nn.Module):
208
  super().__init__()
209
  self.config = config
210
  self.layer_idx = layer_idx
211
- self.head_dim = config.hidden_size // config.num_attention_heads
212
  self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
213
  self.scaling = self.head_dim**-0.5
214
  self.attention_dropout = config.attention_dropout
215
  self.dynamic_mask_ratio = config.dynamic_mask_ratio
216
 
217
- self.ALL_ATTENTION_FUNCTIONS = {
218
- "eager": self.eager_attention_forward,
219
- "flex_attention": self.flex_attention_forward,
220
- "sdpa": self.sdpa_attention_forward,
221
- }
222
-
223
- # Q K V O projections
224
  self.q_proj = nn.Linear(
225
  config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.hidden_bias
226
  )
@@ -230,7 +358,7 @@ class DogeDynamicMaskAttention(nn.Module):
230
  self.v_proj = nn.Linear(
231
  config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.hidden_bias
232
  )
233
- # dynamic mask for the QK^T attention score matrix
234
  self.A = nn.Parameter(torch.zeros(config.num_attention_heads))
235
  self.dt_proj = nn.Linear(
236
  config.num_key_value_heads * self.head_dim, config.num_attention_heads, bias=config.hidden_bias
@@ -247,7 +375,7 @@ class DogeDynamicMaskAttention(nn.Module):
247
  past_key_value: Optional[Cache] = None,
248
  cache_position: Optional[torch.LongTensor] = None,
249
  **kwargs,
250
- ) -> Tuple[torch.Tensor, Optional[Cache]]:
251
  input_shape = hidden_states.shape[:-1]
252
  hidden_shape = (*input_shape, -1, self.head_dim)
253
 
@@ -275,11 +403,18 @@ class DogeDynamicMaskAttention(nn.Module):
275
  attention_mask=attention_mask,
276
  )
277
 
278
- attention_interface: Callable = self.eager_attention_forward
279
  if self.config._attn_implementation != "eager":
280
- attention_interface = self.ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
281
 
282
- attn_output = attention_interface(
 
283
  query_states,
284
  key_states,
285
  value_states,
@@ -291,7 +426,7 @@ class DogeDynamicMaskAttention(nn.Module):
291
 
292
  attn_output = attn_output.reshape(*input_shape, -1).contiguous()
293
  attn_output = self.o_proj(attn_output)
294
- return attn_output
295
 
296
  def prepare_dynamic_mask(
297
  self,
@@ -325,109 +460,6 @@ class DogeDynamicMaskAttention(nn.Module):
325
 
326
  return attn_mask
327
 
328
- def eager_attention_forward(
329
- self,
330
- query: torch.Tensor,
331
- key: torch.Tensor,
332
- value: torch.Tensor,
333
- attention_mask: Optional[torch.Tensor],
334
- scaling: float,
335
- dropout: float = 0.0,
336
- **kwargs,
337
- ) -> torch.Tensor:
338
- key_states = repeat_kv(key, self.num_key_value_groups)
339
- value_states = repeat_kv(value, self.num_key_value_groups)
340
-
341
- # compute attention scores matrix
342
- attn_weights = torch.matmul(query, key_states.transpose(-1, -2)) * scaling
343
- if attention_mask is not None:
344
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
345
- attn_weights = attn_weights + causal_mask
346
-
347
- # upcast attention scores to fp32
348
- attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
349
- attn_weights = F.dropout(attn_weights, p=dropout, training=self.training)
350
-
351
- # apply attention scores to value states
352
- attn_output = torch.matmul(attn_weights, value_states)
353
- attn_output = attn_output.transpose(1, 2).contiguous()
354
- return attn_output
355
-
356
- def sdpa_attention_forward(
357
- self,
358
- query: torch.Tensor,
359
- key: torch.Tensor,
360
- value: torch.Tensor,
361
- attention_mask: Optional[torch.Tensor],
362
- scaling: float,
363
- dropout: float = 0.0,
364
- **kwargs,
365
- ) -> torch.Tensor:
366
- key = repeat_kv(key, self.num_key_value_groups)
367
- value = repeat_kv(value, self.num_key_value_groups)
368
-
369
- causal_mask = attention_mask
370
- if attention_mask is not None:
371
- causal_mask = causal_mask[:, :, :, : key.shape[-2]]
372
-
373
- # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
374
- # Reference: https://github.com/pytorch/pytorch/issues/112577.
375
- query = query.contiguous()
376
- key = key.contiguous()
377
- value = value.contiguous()
378
-
379
- # NOTE: As of pytorch 2.5.1, cuDNN's SDPA backward pass is still incorrect, so we disable cuDNN SDPA (see https://github.com/pytorch/pytorch/issues/138581)
380
- torch.backends.cuda.enable_cudnn_sdp(False)
381
- attn_output = F.scaled_dot_product_attention(
382
- query,
383
- key,
384
- value,
385
- attn_mask=causal_mask,
386
- dropout_p=dropout,
387
- scale=scaling,
388
- )
389
- attn_output = attn_output.transpose(1, 2).contiguous()
390
- return attn_output
391
-
392
- def flex_attention_forward(
393
- self,
394
- query: torch.Tensor,
395
- key: torch.Tensor,
396
- value: torch.Tensor,
397
- attention_mask: Optional[torch.Tensor],
398
- scaling: float,
399
- dropout: float = 0.0,
400
- **kwargs,
401
- ) -> torch.Tensor:
402
- key = repeat_kv(key, self.num_key_value_groups)
403
- value = repeat_kv(value, self.num_key_value_groups)
404
-
405
- causal_mask = attention_mask
406
- if attention_mask is not None:
407
- causal_mask = causal_mask[:, :, :, : key.shape[-2]]
408
-
409
- # TODO: flex_attention: As of pytorch 2.5.1, captured buffers that require grad are not yet supported.
410
- # NOTE: So we only use flex_attention in inference mode.
411
- def causal_mod(score, batch, head, q_idx, kv_idx):
412
- score = score + causal_mask[batch][0][q_idx][kv_idx]
413
- return score
414
-
415
- def dynamic_mod(score, batch, head, q_idx, kv_idx):
416
- score = score + causal_mask[batch][head][q_idx][kv_idx]
417
- return score
418
-
419
- mask_mod = causal_mod if self.is_causal else dynamic_mod
420
-
421
- attn_output = flex_attention(
422
- query,
423
- key,
424
- value,
425
- score_mod=mask_mod,
426
- scale=scaling,
427
- )
428
- attn_output = attn_output.transpose(1, 2).contiguous()
429
- return attn_output
430
-
431
 
432
  class DogeMLP(nn.Module):
433
  def __init__(self, config: DogeConfig):
@@ -464,8 +496,8 @@ class DogeCDMoE(DogeMLP):
464
  self.num_keys = int(math.sqrt(self.num_cdmoe_experts))
465
 
466
  # queries and keys for retrieval experts
467
- self.queries = nn.Linear(self.hidden_dim, self.num_cdmoe_heads * self.expert_retrieval_dim, bias=False)
468
- self.keys = nn.Parameter(torch.zeros(self.num_cdmoe_heads, self.num_keys, 2, self.expert_retrieval_dim // 2))
469
 
470
  # experts
471
  self.down_embed = nn.Embedding(self.num_cdmoe_experts, self.hidden_dim)
@@ -478,13 +510,15 @@ class DogeCDMoE(DogeMLP):
478
  ) -> torch.Tensor:
479
  bsz, seq_len, _ = hidden_states.shape
480
 
481
- # get similarity with queries and keys
482
- queries = self.queries(hidden_states)
483
- queries = queries.view(bsz, seq_len, 2, self.num_cdmoe_heads, -1).permute(2, 0, 1, 3, 4)
484
- sim = torch.einsum("p b t h n, h k p n -> p b t h k", queries, self.keys)
 
 
485
 
486
- # get experts with the highest similarity
487
- (scores_x, scores_y), (indices_x, indices_y) = sim.topk(self.num_cdmoe_experts_per_head, dim=-1)
488
  all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
489
  all_scores = all_scores.view(*scores_x.shape[:-1], -1)
490
  all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
@@ -495,9 +529,9 @@ class DogeCDMoE(DogeMLP):
495
  up_embed = self.up_embed(indices)
496
 
497
  # mix experts states with cross domain states
498
- experts_weights = torch.einsum("b t d, b t h k d -> b t h k", hidden_states, down_embed)
499
  experts_weights = self.act_fn(experts_weights) * scores.softmax(dim=-1)
500
- experts_states = torch.einsum("b t h k, b t h k d -> b t d", experts_weights, up_embed)
501
  hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
502
  hidden_states = hidden_states + experts_states
503
  return hidden_states
@@ -508,13 +542,13 @@ class DogeDecoderLayer(nn.Module):
508
  super().__init__()
509
  self.hidden_dropout = config.hidden_dropout
510
 
511
- self.pre_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
512
  self.self_attn = DogeDynamicMaskAttention(config=config, layer_idx=layer_idx)
513
- self.pre_residual = Residual(config.hidden_size)
514
 
515
- self.post_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
516
  self.feed_forward = DogeMLP(config) if not config.is_moe else DogeCDMoE(config)
517
- self.post_residual = Residual(config.hidden_size)
518
 
519
  def forward(
520
  self,
@@ -531,11 +565,13 @@ class DogeDecoderLayer(nn.Module):
531
  # sequence transformation
532
  residual = hidden_states
533
  hidden_states = self.pre_layernorm(hidden_states)
534
- hidden_states = self.self_attn(
535
  hidden_states=hidden_states,
536
  attention_mask=attention_mask,
537
  position_ids=position_ids,
538
  past_key_value=past_key_value,
 
 
539
  cache_position=cache_position,
540
  position_embeddings=position_embeddings,
541
  **kwargs,
@@ -586,7 +622,7 @@ class DogePreTrainedModel(PreTrainedModel):
586
  _no_split_modules = ["DogeDecoderLayer"]
587
  _skip_keys_device_placement = ["past_key_values"]
588
  _supports_sdpa = True
589
- _supports_flex_attn = True
590
  _supports_cache_class = True
591
  _supports_quantized_cache = True
592
  _supports_static_cache = True
@@ -697,11 +733,11 @@ class DogeModel(DogePreTrainedModel):
697
  self.vocab_size = config.vocab_size
698
 
699
  self.word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
700
- self.rotary_emb = RotaryEmbedding(config)
701
  self.layers = nn.ModuleList(
702
  [DogeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
703
  )
704
- self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
705
  self.gradient_checkpointing = False
706
 
707
  # Initialize weights and apply final processing
@@ -828,9 +864,27 @@ class DogeModel(DogePreTrainedModel):
828
  past_key_values: Cache,
829
  output_attentions: bool,
830
  ):
831
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
832
  using_static_cache = isinstance(past_key_values, StaticCache)
833
 
834
  dtype, device = input_tensor.dtype, input_tensor.device
835
  sequence_length = input_tensor.shape[1]
836
  if using_static_cache:
@@ -842,9 +896,9 @@ class DogeModel(DogePreTrainedModel):
842
  else past_seen_tokens + sequence_length + 1
843
  )
844
 
845
- # in case the provided `attention` mask is 2D, we generate a causal mask here (4D).
846
  causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
847
- attention_mask=attention_mask,
848
  sequence_length=sequence_length,
849
  target_length=target_length,
850
  dtype=dtype,
@@ -853,17 +907,29 @@ class DogeModel(DogePreTrainedModel):
853
  batch_size=input_tensor.shape[0],
854
  )
855
 
856
  return causal_mask
857
 
858
  @staticmethod
859
  def _prepare_4d_causal_attention_mask_with_cache_position(
860
- attention_mask: torch.Tensor = None,
861
- sequence_length: int = None,
862
- target_length: int = None,
863
- dtype: torch.dtype = None,
864
- device: torch.device = None,
865
- cache_position: torch.Tensor = None,
866
- batch_size: int = None,
867
  **kwargs,
868
  ):
869
  """
@@ -894,10 +960,7 @@ class DogeModel(DogePreTrainedModel):
894
  else:
895
  min_dtype = torch.finfo(dtype).min
896
  causal_mask = torch.full(
897
- (sequence_length, target_length),
898
- fill_value=min_dtype,
899
- dtype=dtype,
900
- device=device,
901
  )
902
  if sequence_length != 1:
903
  causal_mask = torch.triu(causal_mask, diagonal=1)
@@ -915,9 +978,6 @@ class DogeModel(DogePreTrainedModel):
915
  return causal_mask
916
 
917
 
918
- class KwargsForCausalLM(LossKwargs): ...
919
-
920
-
921
  class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
922
  _tied_weights_keys = ["lm_head.weight"]
923
  _tp_plan = {"lm_head": "colwise_rep"}
@@ -950,7 +1010,6 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
950
  def set_decoder(self, decoder):
951
  self.model = decoder
952
 
953
- @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
954
  @add_start_docstrings_to_model_forward(DOGE_INPUTS_DOCSTRING)
955
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
956
  def forward(
@@ -966,8 +1025,8 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
966
  output_hidden_states: Optional[bool] = None,
967
  return_dict: Optional[bool] = None,
968
  cache_position: Optional[torch.LongTensor] = None,
969
- logits_to_keep: int = 0,
970
- **kwargs: Unpack[KwargsForCausalLM],
971
  ) -> Union[Tuple, CausalLMOutputWithPast]:
972
  r"""
973
  Args:
@@ -1121,17 +1180,20 @@ class DogeForSequenceClassification(DogePreTrainedModel):
1121
  if self.config.pad_token_id is None and batch_size != 1:
1122
  raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1123
  if self.config.pad_token_id is None:
1124
- sequence_lengths = -1
1125
  else:
1126
- if input_ids is not None:
1127
- # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1128
- sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1129
- sequence_lengths = sequence_lengths % input_ids.shape[-1]
1130
- sequence_lengths = sequence_lengths.to(logits.device)
1131
- else:
1132
- sequence_lengths = -1
1133
 
1134
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1135
 
1136
  loss = None
1137
  if labels is not None:
 
28
  import torch
29
  import torch.nn.functional as F
30
  from torch import nn
31
+
32
  from transformers.activations import ACT2FN
33
  from transformers.cache_utils import Cache, DynamicCache, StaticCache
34
  from transformers.generation import GenerationMixin
35
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
36
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 
 
 
37
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
38
  from transformers.modeling_utils import PreTrainedModel
39
  from transformers.processing_utils import Unpack
 
45
  logging,
46
  replace_return_docstrings,
47
  )
 
 
48
  from .configuration_doge import DogeConfig
49
 
 
50
  if is_torch_flex_attn_available():
51
  from torch.nn.attention.flex_attention import flex_attention
52
 
 
53
  logger = logging.get_logger(__name__)
54
 
55
  _CONFIG_FOR_DOC = "DogeConfig"
56
 
57
 
58
+ class DogeRMSNorm(nn.Module):
59
  def __init__(self, hidden_size, eps=1e-6):
60
  """
61
+ DogeRMSNorm is equivalent to T5LayerNorm
62
  """
63
  super().__init__()
64
  self.weight = nn.Parameter(torch.ones(hidden_size))
 
75
  return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
76
 
77
 
78
+ class DogeResidual(nn.Module):
79
  def __init__(self, hidden_size):
80
  super().__init__()
81
  self.weight = nn.Parameter(torch.ones(hidden_size))
 
87
  return f"{tuple(self.weight.shape)}"
88
 
89
 
90
+ class DogeRotaryEmbedding(nn.Module):
91
+ def __init__(self, config: DogeConfig, device=None):
92
  super().__init__()
93
  # BC: "rope_type" was originally "type"
94
  if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
 
149
 
150
 
151
  def rotate_half(x):
152
+ """Rotates half the hidden dims of the input."""
 
 
153
  x1 = x[..., : x.shape[-1] // 2]
154
  x2 = x[..., x.shape[-1] // 2 :]
155
  return torch.cat((-x2, x1), dim=-1)
 
167
  Deprecated and unused.
168
  unsqueeze_dim (`int`, *optional*, defaults to 1):
169
  The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
170
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
171
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
172
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
173
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
174
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
175
  Returns:
176
  `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
177
  """
 
184
 
185
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
186
  """
187
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
188
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
189
  """
190
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
191
  if n_rep == 1:
 
194
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
195
 
196
 
197
+ def eager_attention_forward(
198
+ module: nn.Module,
199
+ query: torch.Tensor,
200
+ key: torch.Tensor,
201
+ value: torch.Tensor,
202
+ attention_mask: Optional[torch.Tensor],
203
+ scaling: float,
204
+ dropout: float = 0.0,
205
+ **kwargs,
206
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
207
+ key_states = repeat_kv(key, module.num_key_value_groups)
208
+ value_states = repeat_kv(value, module.num_key_value_groups)
209
+
210
+ attn_weights = torch.matmul(query, key_states.transpose(-1, -2)) * scaling
211
+ if attention_mask is not None:
212
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
213
+ attn_weights = attn_weights + causal_mask
214
+
215
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
216
+ attn_weights = F.dropout(attn_weights, p=dropout, training=module.training)
217
+ attn_output = torch.matmul(attn_weights, value_states)
218
+ attn_output = attn_output.transpose(1, 2).contiguous()
219
+
220
+ return attn_output, attn_weights
221
+
222
+
223
+ def sdpa_attention_forward(
224
+ module: nn.Module,
225
+ query: torch.Tensor,
226
+ key: torch.Tensor,
227
+ value: torch.Tensor,
228
+ attention_mask: Optional[torch.Tensor],
229
+ dropout: float = 0.0,
230
+ scaling: Optional[float] = None,
231
+ is_causal: Optional[bool] = None,
232
+ **kwargs,
233
+ ) -> Tuple[torch.Tensor, None]:
234
+ key = repeat_kv(key, module.num_key_value_groups)
235
+ value = repeat_kv(value, module.num_key_value_groups)
236
+
237
+ causal_mask = attention_mask
238
+ if attention_mask is not None:
239
+ causal_mask = causal_mask[:, :, :, : key.shape[-2]]
240
+
241
+ # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
242
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
243
+ query = query.contiguous()
244
+ key = key.contiguous()
245
+ value = value.contiguous()
246
+
247
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
248
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
249
+ if is_causal is None:
250
+ is_causal = causal_mask is None and query.shape[2] > 1
251
+
252
+ # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
253
+ # We convert it to a bool for the SDPA kernel that only accepts bools.
254
+ if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
255
+ is_causal = is_causal.item()
256
+
257
+ # NOTE: As of pytorch 2.5.1, SDPA backward pass of cuDNN is still incorrect, so we disable cuDNN SDPA (see https://github.com/pytorch/pytorch/issues/138581)
258
+ torch.backends.cuda.enable_cudnn_sdp(False)
259
+ attn_output = F.scaled_dot_product_attention(
260
+ query=query,
261
+ key=key,
262
+ value=value,
263
+ attn_mask=causal_mask,
264
+ dropout_p=dropout,
265
+ scale=scaling,
266
+ is_causal=is_causal,
267
+ )
268
+ attn_output = attn_output.transpose(1, 2).contiguous()
269
+
270
+ return attn_output, None
271
+
272
+
273
+ def flex_attention_forward(
274
+ module: nn.Module,
275
+ query: torch.Tensor,
276
+ key: torch.Tensor,
277
+ value: torch.Tensor,
278
+ attention_mask: Optional[torch.Tensor],
279
+ scaling: Optional[float] = None,
280
+ is_causal: Optional[bool] = None,
281
+ softcap: Optional[float] = None,
282
+ head_mask: Optional[torch.Tensor] = None,
283
+ **kwargs,
284
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
285
+ causal_mask = attention_mask
286
+ if attention_mask is not None:
287
+ causal_mask = causal_mask[:, :, :, : key.shape[-2]]
288
+
289
+ if is_causal is None:
290
+ is_causal = causal_mask is None and query.shape[2] > 1
291
+
292
+ def causal_mod(score, batch, head, q_idx, kv_idx):
293
+ if softcap is not None:
294
+ score = softcap * torch.tanh(score / softcap)
295
+ if causal_mask is not None:
296
+ score = score + causal_mask[batch][0][q_idx][kv_idx]
297
+ if head_mask is not None:
298
+ score = score + head_mask[batch][head][0][0]
299
+ return score
300
+
301
+ def dynamic_mod(score, batch, head, q_idx, kv_idx):
302
+ if softcap is not None:
303
+ score = softcap * torch.tanh(score / softcap)
304
+ if causal_mask is not None:
305
+ score = score + causal_mask[batch][head][q_idx][kv_idx]
306
+ if head_mask is not None:
307
+ score = score + head_mask[batch][head][0][0]
308
+ return score
309
+
310
+ # TODO: flex_attention: As of pytorch 2.5.1, captured buffers that require grad are not yet supported.
311
+ # NOTE: So we only use flex_attention in inference mode.
312
+ mask_mod = causal_mod if is_causal or module.training else dynamic_mod
313
+
314
+ attn_output, attention_weights = flex_attention(
315
+ query=query,
316
+ key=key,
317
+ value=value,
318
+ score_mod=mask_mod,
319
+ enable_gqa=True,
320
+ scale=scaling,
321
+ # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
322
+ # For simplification, we thus always return it as no additional computations are introduced.
323
+ return_lse=True,
324
+ )
325
+ # lse is returned in float32
326
+ attention_weights = attention_weights.to(value.dtype)
327
+ attn_output = attn_output.transpose(1, 2).contiguous()
328
+
329
+ return attn_output, attention_weights
330
+
331
+
332
+ ALL_ATTENTION_FUNCTIONS = {
333
+ "eager": eager_attention_forward,
334
+ "sdpa": sdpa_attention_forward,
335
+ "flex_attention": flex_attention_forward,
336
+ }
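The per-instance attention methods are replaced by module-level kernels looked up by name in `ALL_ATTENTION_FUNCTIONS`, keyed on `config._attn_implementation`. A stripped-down sketch of that dispatch pattern; the stand-in functions below take far fewer arguments than the real ones:

```python
# Minimal sketch of name-based attention dispatch (stand-ins, not the functions above).
from typing import Callable, Dict
import torch
import torch.nn.functional as F

def eager_attn(q, k, v):
    w = F.softmax(q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5, dim=-1)
    return w @ v

def sdpa_attn(q, k, v):
    return F.scaled_dot_product_attention(q, k, v)

ATTN_IMPLS: Dict[str, Callable] = {"eager": eager_attn, "sdpa": sdpa_attn}

q = k = v = torch.randn(1, 2, 4, 8)   # (batch, heads, seq_len, head_dim)
out = ATTN_IMPLS["sdpa"](q, k, v)     # select the kernel by configured name
assert out.shape == q.shape
```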
337
+
338
+
339
  class DogeDynamicMaskAttention(nn.Module):
340
  """Dynamic Mask Attention from 'Wonderful Matrices' paper."""
341
 
 
343
  super().__init__()
344
  self.config = config
345
  self.layer_idx = layer_idx
346
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
347
  self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
348
  self.scaling = self.head_dim**-0.5
349
  self.attention_dropout = config.attention_dropout
350
  self.dynamic_mask_ratio = config.dynamic_mask_ratio
351
 
352
  self.q_proj = nn.Linear(
353
  config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.hidden_bias
354
  )
 
358
  self.v_proj = nn.Linear(
359
  config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.hidden_bias
360
  )
361
+ # dynamic mask for the QK^T attention weights matrix
362
  self.A = nn.Parameter(torch.zeros(config.num_attention_heads))
363
  self.dt_proj = nn.Linear(
364
  config.num_key_value_heads * self.head_dim, config.num_attention_heads, bias=config.hidden_bias
 
375
  past_key_value: Optional[Cache] = None,
376
  cache_position: Optional[torch.LongTensor] = None,
377
  **kwargs,
378
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
379
  input_shape = hidden_states.shape[:-1]
380
  hidden_shape = (*input_shape, -1, self.head_dim)
381
 
 
403
  attention_mask=attention_mask,
404
  )
405
 
406
+ attention_interface: Callable = eager_attention_forward
407
  if self.config._attn_implementation != "eager":
408
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
409
+ logger.warning_once(
410
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
411
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
412
+ )
413
+ else:
414
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
415
 
416
+ attn_output, attn_weights = attention_interface(
417
+ self,
418
  query_states,
419
  key_states,
420
  value_states,
 
426
 
427
  attn_output = attn_output.reshape(*input_shape, -1).contiguous()
428
  attn_output = self.o_proj(attn_output)
429
+ return attn_output, attn_weights
430
 
431
  def prepare_dynamic_mask(
432
  self,
 
460
 
461
  return attn_mask
462
 
463
 
464
  class DogeMLP(nn.Module):
465
  def __init__(self, config: DogeConfig):
 
496
  self.num_keys = int(math.sqrt(self.num_cdmoe_experts))
497
 
498
  # queries and keys for retrieval experts
499
+ self.queries_proj = nn.Linear(self.hidden_dim, self.num_cdmoe_heads * self.expert_retrieval_dim, bias=False)
500
+ self.keys = nn.Parameter(torch.zeros(self.num_cdmoe_heads, self.expert_retrieval_dim, self.num_keys))
501
 
502
  # experts
503
  self.down_embed = nn.Embedding(self.num_cdmoe_experts, self.hidden_dim)
 
510
  ) -> torch.Tensor:
511
  bsz, seq_len, _ = hidden_states.shape
512
 
513
+ # get routing weights with queries and keys
514
+ queries = self.queries_proj(hidden_states)
515
+ queries = queries.view(2, self.num_cdmoe_heads, bsz * seq_len, -1)
516
+ keys = self.keys.view(2, self.num_cdmoe_heads, -1, self.num_keys)
517
+ routing_weights = torch.matmul(queries, keys)
518
+ routing_weights = routing_weights.transpose(-2, -3).view(2, bsz, seq_len, self.num_cdmoe_heads, self.num_keys)
519
 
520
+ # get experts with the highest routing weights
521
+ (scores_x, scores_y), (indices_x, indices_y) = routing_weights.topk(self.num_cdmoe_experts_per_head, dim=-1)
522
  all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
523
  all_scores = all_scores.view(*scores_x.shape[:-1], -1)
524
  all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
 
529
  up_embed = self.up_embed(indices)
530
 
531
  # mix experts states with cross domain states
532
+ experts_weights = torch.sum(hidden_states[:, :, None, None, :] * down_embed, dim=-1)
533
  experts_weights = self.act_fn(experts_weights) * scores.softmax(dim=-1)
534
+ experts_states = torch.sum(experts_weights[:, :, :, :, None] * up_embed, dim=(-2, -3))
535
  hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
536
  hidden_states = hidden_states + experts_states
537
  return hidden_states
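The rewritten routing above scores an expert as the sum of two sub-key scores and turns the (x, y) index pair into a flat expert id, so only 2 * num_keys keys are searched for num_keys**2 experts. A small sketch of that product-key combination for a single head and token, with illustrative sizes; the final `topk`/`gather` step is not shown in the hunk and is reconstructed here only for illustration:

```python
# Illustrative product-key retrieval: combine two top-k searches into flat expert ids.
import torch

num_keys, k = 4, 2                                   # num_keys**2 experts, top-k per axis
scores_x, indices_x = torch.randn(num_keys).topk(k)  # stand-in for routing_weights[0]
scores_y, indices_y = torch.randn(num_keys).topk(k)  # stand-in for routing_weights[1]

all_scores = (scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)).reshape(-1)                # k*k candidates
all_indices = (indices_x.unsqueeze(-1) * num_keys + indices_y.unsqueeze(-2)).reshape(-1)  # flat ids

scores, position = all_scores.topk(k)
indices = all_indices.gather(-1, position)           # chosen expert ids in [0, num_keys**2)
```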
 
542
  super().__init__()
543
  self.hidden_dropout = config.hidden_dropout
544
 
545
+ self.pre_layernorm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
546
  self.self_attn = DogeDynamicMaskAttention(config=config, layer_idx=layer_idx)
547
+ self.pre_residual = DogeResidual(config.hidden_size)
548
 
549
+ self.post_layernorm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
550
  self.feed_forward = DogeMLP(config) if not config.is_moe else DogeCDMoE(config)
551
+ self.post_residual = DogeResidual(config.hidden_size)
552
 
553
  def forward(
554
  self,
 
565
  # sequence transformation
566
  residual = hidden_states
567
  hidden_states = self.pre_layernorm(hidden_states)
568
+ hidden_states, self_attn_weights = self.self_attn(
569
  hidden_states=hidden_states,
570
  attention_mask=attention_mask,
571
  position_ids=position_ids,
572
  past_key_value=past_key_value,
573
+ output_attentions=output_attentions,
574
+ use_cache=use_cache,
575
  cache_position=cache_position,
576
  position_embeddings=position_embeddings,
577
  **kwargs,
 
622
  _no_split_modules = ["DogeDecoderLayer"]
623
  _skip_keys_device_placement = ["past_key_values"]
624
  _supports_sdpa = True
625
+ # _supports_flex_attn = True # TODO: enable this when flex_attention is fully supported
626
  _supports_cache_class = True
627
  _supports_quantized_cache = True
628
  _supports_static_cache = True
 
733
  self.vocab_size = config.vocab_size
734
 
735
  self.word_embed = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
736
+ self.rotary_emb = DogeRotaryEmbedding(config)
737
  self.layers = nn.ModuleList(
738
  [DogeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
739
  )
740
+ self.final_layernorm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
741
  self.gradient_checkpointing = False
742
 
743
  # Initialize weights and apply final processing
 
864
  past_key_values: Cache,
865
  output_attentions: bool,
866
  ):
867
+ if self.config._attn_implementation == "flash_attention_2":
868
+ if attention_mask is not None and (attention_mask == 0.0).any():
869
+ return attention_mask
870
+ return None
871
+
872
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
873
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
874
+ # to infer the attention mask.
875
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
876
  using_static_cache = isinstance(past_key_values, StaticCache)
877
 
878
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
879
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
880
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
881
+ attention_mask,
882
+ inputs_embeds=input_tensor,
883
+ past_key_values_length=past_seen_tokens,
884
+ is_training=self.training,
885
+ ):
886
+ return None
887
+
888
  dtype, device = input_tensor.dtype, input_tensor.device
889
  sequence_length = input_tensor.shape[1]
890
  if using_static_cache:
 
896
  else past_seen_tokens + sequence_length + 1
897
  )
898
 
899
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
900
  causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
901
+ attention_mask,
902
  sequence_length=sequence_length,
903
  target_length=target_length,
904
  dtype=dtype,
 
907
  batch_size=input_tensor.shape[0],
908
  )
909
 
910
+ if (
911
+ self.config._attn_implementation == "sdpa"
912
+ and attention_mask is not None
913
+ and attention_mask.device.type in ["cuda", "xpu"]
914
+ and not output_attentions
915
+ ):
916
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
917
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
918
+ # Details: https://github.com/pytorch/pytorch/issues/110213
919
+ min_dtype = torch.finfo(dtype).min
920
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
921
+
922
  return causal_mask
923
 
924
  @staticmethod
925
  def _prepare_4d_causal_attention_mask_with_cache_position(
926
+ attention_mask: torch.Tensor,
927
+ sequence_length: int,
928
+ target_length: int,
929
+ dtype: torch.dtype,
930
+ device: torch.device,
931
+ cache_position: torch.Tensor,
932
+ batch_size: int,
933
  **kwargs,
934
  ):
935
  """
 
960
  else:
961
  min_dtype = torch.finfo(dtype).min
962
  causal_mask = torch.full(
963
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
 
 
 
964
  )
965
  if sequence_length != 1:
966
  causal_mask = torch.triu(causal_mask, diagonal=1)
 
978
  return causal_mask
979
 
980
 
 
 
 
981
  class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
982
  _tied_weights_keys = ["lm_head.weight"]
983
  _tp_plan = {"lm_head": "colwise_rep"}
 
1010
  def set_decoder(self, decoder):
1011
  self.model = decoder
1012
 
 
1013
  @add_start_docstrings_to_model_forward(DOGE_INPUTS_DOCSTRING)
1014
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1015
  def forward(
 
1025
  output_hidden_states: Optional[bool] = None,
1026
  return_dict: Optional[bool] = None,
1027
  cache_position: Optional[torch.LongTensor] = None,
1028
+ logits_to_keep: Union[int, torch.Tensor] = 0,
1029
+ **kwargs: Unpack[LossKwargs],
1030
  ) -> Union[Tuple, CausalLMOutputWithPast]:
1031
  r"""
1032
  Args:
 
1180
  if self.config.pad_token_id is None and batch_size != 1:
1181
  raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1182
  if self.config.pad_token_id is None:
1183
+ last_non_pad_token = -1
1184
+ elif input_ids is not None:
1185
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
1186
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
1187
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
1188
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
1189
  else:
1190
+ last_non_pad_token = -1
1191
+ logger.warning_once(
1192
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1193
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1194
+ )
 
 
1195
 
1196
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
1197
 
1198
  loss = None
1199
  if labels is not None:
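The sequence-classification head now pools the rightmost non-pad token instead of the position before the first pad token, which works for both left- and right-padded batches. A small worked example of the index computation introduced above:

```python
# Worked example of the last-non-pad-token selection used for pooling.
import torch

pad_token_id = 2
input_ids = torch.tensor(
    [[5, 6, 7, 2, 2],     # right-padded -> last real token at index 2
     [2, 2, 8, 9, 10]]    # left-padded  -> last real token at index 4
)

non_pad_mask = (input_ids != pad_token_id).int()
token_indices = torch.arange(input_ids.shape[-1])
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
print(last_non_pad_token)  # tensor([2, 4])

batch_size, num_labels = input_ids.shape[0], 3
logits = torch.randn(batch_size, input_ids.shape[-1], num_labels)     # fake per-token logits
pooled_logits = logits[torch.arange(batch_size), last_non_pad_token]  # (batch_size, num_labels)
```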