llllvvuu committed
Commit 4deb747
1 Parent(s): 944fec2

fix: modeling_deepseek.py should use `deepseek` instead of `deepseek_v2` architecture


I have copied the file from https://huggingface.co/deepseek-ai/deepseek-moe-16b-chat/edit/main/modeling_deepseek.py

I believe this is the correct file, since the model's weight dict has matching keys (it uses the original self_attn architecture rather than the DeepseekV2 one).
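
As a quick way to double-check this, here is a minimal sketch (not part of the commit) that builds the model skeleton from the copied modeling file and compares its parameter names against the checkpoint's weight keys; the repo id and the `checkpoint_keys` stub below are assumptions:

# Hypothetical sanity check: does the copied modeling_deepseek.py expect the same
# parameter names that the shipped checkpoint provides?
import torch
from transformers import AutoConfig, AutoModelForCausalLM

repo = "deepseek-ai/deepseek-moe-16b-chat"  # assumed source of the copied modeling file

config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
with torch.device("meta"):  # build the module tree without allocating real weights
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

expected_keys = set(model.state_dict().keys())

# In practice, read the checkpoint's key list from its weight index file
# (or from the shard files themselves); stubbed here so the sketch runs end to end.
checkpoint_keys = set(expected_keys)

print("missing from checkpoint:", sorted(expected_keys - checkpoint_keys)[:10])
print("unexpected in checkpoint:", sorted(checkpoint_keys - expected_keys)[:10])

If both printed lists come out empty, the copied file is compatible with the shipped weights.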

Files changed (1)
  1. modeling_deepseek.py +374 -730
modeling_deepseek.py CHANGED
@@ -5,7 +5,7 @@
5
  # and OPT implementations in this library. It has been modified from its
6
  # original forms to accommodate minor architectural differences compared
7
  # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
- #
9
  # Licensed under the Apache License, Version 2.0 (the "License");
10
  # you may not use this file except in compliance with the License.
11
  # You may obtain a copy of the License at
@@ -34,17 +34,11 @@ from transformers.modeling_attn_mask_utils import (
34
  AttentionMaskConverter,
35
  _prepare_4d_attention_mask,
36
  _prepare_4d_causal_attention_mask,
 
37
  )
38
- from transformers.modeling_outputs import (
39
- BaseModelOutputWithPast,
40
- CausalLMOutputWithPast,
41
- SequenceClassifierOutputWithPast,
42
- )
43
  from transformers.modeling_utils import PreTrainedModel
44
- from transformers.pytorch_utils import (
45
- ALL_LAYERNORM_LAYERS,
46
- is_torch_greater_or_equal_than_1_13,
47
- )
48
  from transformers.utils import (
49
  add_start_docstrings,
50
  add_start_docstrings_to_model_forward,
@@ -54,9 +48,8 @@ from transformers.utils import (
54
  replace_return_docstrings,
55
  )
56
  from transformers.utils.import_utils import is_torch_fx_available
57
- from .configuration_deepseek import DeepseekV2Config
58
- import torch.distributed as dist
59
- import numpy as np
60
 
61
  if is_flash_attn_2_available():
62
  from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -74,16 +67,14 @@ if is_torch_fx_available():
74
 
75
  logger = logging.get_logger(__name__)
76
 
77
- _CONFIG_FOR_DOC = "DeepseekV2Config"
78
 
79
 
80
  def _get_unpad_data(attention_mask):
81
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
82
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
83
  max_seqlen_in_batch = seqlens_in_batch.max().item()
84
- cu_seqlens = F.pad(
85
- torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
86
- )
87
  return (
88
  indices,
89
  cu_seqlens,
@@ -91,10 +82,28 @@ def _get_unpad_data(attention_mask):
91
  )
92
 
93
 
94
- class DeepseekV2RMSNorm(nn.Module):
95
  def __init__(self, hidden_size, eps=1e-6):
96
  """
97
- DeepseekV2RMSNorm is equivalent to T5LayerNorm
98
  """
99
  super().__init__()
100
  self.weight = nn.Parameter(torch.ones(hidden_size))
@@ -108,34 +117,29 @@ class DeepseekV2RMSNorm(nn.Module):
108
  return self.weight * hidden_states.to(input_dtype)
109
 
110
 
111
- ALL_LAYERNORM_LAYERS.append(DeepseekV2RMSNorm)
112
 
113
 
114
- class DeepseekV2RotaryEmbedding(nn.Module):
115
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
116
  super().__init__()
117
 
118
  self.dim = dim
119
  self.max_position_embeddings = max_position_embeddings
120
  self.base = base
121
- inv_freq = 1.0 / (
122
- self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
123
- )
124
  self.register_buffer("inv_freq", inv_freq, persistent=False)
125
 
126
  # Build here to make `torch.jit.trace` work.
127
  self._set_cos_sin_cache(
128
- seq_len=max_position_embeddings,
129
- device=self.inv_freq.device,
130
- dtype=torch.get_default_dtype(),
131
  )
132
  self.max_seq_len_cached = None
133
 
 
134
  def _set_cos_sin_cache(self, seq_len, device, dtype):
135
  self.max_seq_len_cached = seq_len
136
- t = torch.arange(
137
- self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
138
- )
139
 
140
  freqs = torch.outer(t, self.inv_freq.to(t.device))
141
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
@@ -154,26 +158,17 @@ class DeepseekV2RotaryEmbedding(nn.Module):
154
  )
155
 
156
 
157
- # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV2
158
- class DeepseekV2LinearScalingRotaryEmbedding(DeepseekV2RotaryEmbedding):
159
- """DeepseekV2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
160
 
161
- def __init__(
162
- self,
163
- dim,
164
- max_position_embeddings=2048,
165
- base=10000,
166
- device=None,
167
- scaling_factor=1.0,
168
- ):
169
  self.scaling_factor = scaling_factor
170
  super().__init__(dim, max_position_embeddings, base, device)
171
 
172
  def _set_cos_sin_cache(self, seq_len, device, dtype):
173
  self.max_seq_len_cached = seq_len
174
- t = torch.arange(
175
- self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
176
- )
177
  t = t / self.scaling_factor
178
 
179
  freqs = torch.outer(t, self.inv_freq)
@@ -183,18 +178,11 @@ class DeepseekV2LinearScalingRotaryEmbedding(DeepseekV2RotaryEmbedding):
183
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
184
 
185
 
186
- # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV2
187
- class DeepseekV2DynamicNTKScalingRotaryEmbedding(DeepseekV2RotaryEmbedding):
188
- """DeepseekV2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
189
 
190
- def __init__(
191
- self,
192
- dim,
193
- max_position_embeddings=2048,
194
- base=10000,
195
- device=None,
196
- scaling_factor=1.0,
197
- ):
198
  self.scaling_factor = scaling_factor
199
  super().__init__(dim, max_position_embeddings, base, device)
200
 
@@ -203,17 +191,12 @@ class DeepseekV2DynamicNTKScalingRotaryEmbedding(DeepseekV2RotaryEmbedding):
203
 
204
  if seq_len > self.max_position_embeddings:
205
  base = self.base * (
206
- (self.scaling_factor * seq_len / self.max_position_embeddings)
207
- - (self.scaling_factor - 1)
208
  ) ** (self.dim / (self.dim - 2))
209
- inv_freq = 1.0 / (
210
- base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
211
- )
212
  self.register_buffer("inv_freq", inv_freq, persistent=False)
213
 
214
- t = torch.arange(
215
- self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
216
- )
217
 
218
  freqs = torch.outer(t, self.inv_freq)
219
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
@@ -222,111 +205,6 @@ class DeepseekV2DynamicNTKScalingRotaryEmbedding(DeepseekV2RotaryEmbedding):
222
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
223
 
224
 
225
- # Inverse dim formula to find dim based on number of rotations
226
- def yarn_find_correction_dim(
227
- num_rotations, dim, base=10000, max_position_embeddings=2048
228
- ):
229
- return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
230
- 2 * math.log(base)
231
- )
232
-
233
-
234
- # Find dim range bounds based on rotations
235
- def yarn_find_correction_range(
236
- low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
237
- ):
238
- low = math.floor(
239
- yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
240
- )
241
- high = math.ceil(
242
- yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
243
- )
244
- return max(low, 0), min(high, dim - 1) # Clamp values just in case
245
-
246
-
247
- def yarn_get_mscale(scale=1, mscale=1):
248
- if scale <= 1:
249
- return 1.0
250
- return 0.1 * mscale * math.log(scale) + 1.0
251
-
252
-
253
- def yarn_linear_ramp_mask(min, max, dim):
254
- if min == max:
255
- max += 0.001 # Prevent singularity
256
-
257
- linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
258
- ramp_func = torch.clamp(linear_func, 0, 1)
259
- return ramp_func
260
-
261
-
262
- class DeepseekV2YarnRotaryEmbedding(DeepseekV2RotaryEmbedding):
263
-
264
- def __init__(
265
- self,
266
- dim,
267
- max_position_embeddings=2048,
268
- base=10000,
269
- device=None,
270
- scaling_factor=1.0,
271
- original_max_position_embeddings=4096,
272
- beta_fast=32,
273
- beta_slow=1,
274
- mscale=1,
275
- mscale_all_dim=0,
276
- ):
277
- self.scaling_factor = scaling_factor
278
- self.original_max_position_embeddings = original_max_position_embeddings
279
- self.beta_fast = beta_fast
280
- self.beta_slow = beta_slow
281
- self.mscale = mscale
282
- self.mscale_all_dim = mscale_all_dim
283
- super().__init__(dim, max_position_embeddings, base, device)
284
-
285
- def _set_cos_sin_cache(self, seq_len, device, dtype):
286
- self.max_seq_len_cached = seq_len
287
- dim = self.dim
288
-
289
- freq_extra = 1.0 / (
290
- self.base
291
- ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
292
- )
293
- freq_inter = 1.0 / (
294
- self.scaling_factor
295
- * self.base
296
- ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
297
- )
298
-
299
- low, high = yarn_find_correction_range(
300
- self.beta_fast,
301
- self.beta_slow,
302
- dim,
303
- self.base,
304
- self.original_max_position_embeddings,
305
- )
306
- inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
307
- device=device, dtype=torch.float32
308
- )
309
- inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
310
- self.register_buffer("inv_freq", inv_freq, persistent=False)
311
-
312
- t = torch.arange(seq_len, device=device, dtype=torch.float32)
313
-
314
- freqs = torch.outer(t, inv_freq)
315
-
316
- _mscale = float(
317
- yarn_get_mscale(self.scaling_factor, self.mscale)
318
- / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
319
- )
320
-
321
- emb = torch.cat((freqs, freqs), dim=-1)
322
- self.register_buffer(
323
- "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False
324
- )
325
- self.register_buffer(
326
- "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False
327
- )
328
-
329
-
330
  # Copied from transformers.models.llama.modeling_llama.rotate_half
331
  def rotate_half(x):
332
  """Rotates half the hidden dims of the input."""
@@ -359,26 +237,17 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
359
  """
360
  cos = cos[position_ids].unsqueeze(unsqueeze_dim)
361
  sin = sin[position_ids].unsqueeze(unsqueeze_dim)
362
-
363
- b, h, s, d = q.shape
364
- q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
365
-
366
- b, h, s, d = k.shape
367
- k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
368
-
369
  q_embed = (q * cos) + (rotate_half(q) * sin)
370
  k_embed = (k * cos) + (rotate_half(k) * sin)
371
  return q_embed, k_embed
372
 
373
 
374
- class DeepseekV2MLP(nn.Module):
375
- def __init__(self, config, hidden_size=None, intermediate_size=None):
376
  super().__init__()
377
  self.config = config
378
  self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
379
- self.intermediate_size = (
380
- config.intermediate_size if intermediate_size is None else intermediate_size
381
- )
382
 
383
  self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
384
  self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -386,7 +255,25 @@ class DeepseekV2MLP(nn.Module):
386
  self.act_fn = ACT2FN[config.hidden_act]
387
 
388
  def forward(self, x):
389
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
390
  return down_proj
391
 
392
 
@@ -396,75 +283,39 @@ class MoEGate(nn.Module):
396
  self.config = config
397
  self.top_k = config.num_experts_per_tok
398
  self.n_routed_experts = config.n_routed_experts
399
- self.routed_scaling_factor = config.routed_scaling_factor
400
  self.scoring_func = config.scoring_func
401
  self.alpha = config.aux_loss_alpha
402
  self.seq_aux = config.seq_aux
403
- self.topk_method = config.topk_method
404
- self.n_group = config.n_group
405
- self.topk_group = config.topk_group
406
 
407
  # topk selection algorithm
408
  self.norm_topk_prob = config.norm_topk_prob
409
  self.gating_dim = config.hidden_size
410
- self.weight = nn.Parameter(
411
- torch.empty((self.n_routed_experts, self.gating_dim))
412
- )
413
  self.reset_parameters()
414
 
415
  def reset_parameters(self) -> None:
416
- import torch.nn.init as init
417
-
418
  init.kaiming_uniform_(self.weight, a=math.sqrt(5))
419
-
420
  def forward(self, hidden_states):
421
- bsz, seq_len, h = hidden_states.shape
422
  ### compute gating score
423
  hidden_states = hidden_states.view(-1, h)
424
- logits = F.linear(
425
- hidden_states.type(torch.float32), self.weight.type(torch.float32), None
426
- )
427
- if self.scoring_func == "softmax":
428
- scores = logits.softmax(dim=-1, dtype=torch.float32)
429
  else:
430
- raise NotImplementedError(
431
- f"insupportable scoring function for MoE gating: {self.scoring_func}"
432
- )
433
-
434
  ### select top-k experts
435
- if self.topk_method == "greedy":
436
- topk_weight, topk_idx = torch.topk(
437
- scores, k=self.top_k, dim=-1, sorted=False
438
- )
439
- elif self.topk_method == "group_limited_greedy":
440
- group_scores = (
441
- scores.view(bsz * seq_len, self.n_group, -1).max(dim=-1).values
442
- ) # [n, n_group]
443
- group_idx = torch.topk(
444
- group_scores, k=self.topk_group, dim=-1, sorted=False
445
- )[
446
- 1
447
- ] # [n, top_k_group]
448
- group_mask = torch.zeros_like(group_scores) # [n, n_group]
449
- group_mask.scatter_(1, group_idx, 1) # [n, n_group]
450
- score_mask = (
451
- group_mask.unsqueeze(-1)
452
- .expand(
453
- bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group
454
- )
455
- .reshape(bsz * seq_len, -1)
456
- ) # [n, e]
457
- tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
458
- topk_weight, topk_idx = torch.topk(
459
- tmp_scores, k=self.top_k, dim=-1, sorted=False
460
- )
461
-
462
  ### norm gate to sum 1
463
  if self.top_k > 1 and self.norm_topk_prob:
464
  denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
465
  topk_weight = topk_weight / denominator
466
- else:
467
- topk_weight = topk_weight * self.routed_scaling_factor
468
  ### expert-level computation auxiliary loss
469
  if self.training and self.alpha > 0.0:
470
  scores_for_aux = scores
@@ -473,21 +324,11 @@ class MoEGate(nn.Module):
473
  topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
474
  if self.seq_aux:
475
  scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
476
- ce = torch.zeros(
477
- bsz, self.n_routed_experts, device=hidden_states.device
478
- )
479
- ce.scatter_add_(
480
- 1,
481
- topk_idx_for_aux_loss,
482
- torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device),
483
- ).div_(seq_len * aux_topk / self.n_routed_experts)
484
- aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(
485
- dim=1
486
- ).mean() * self.alpha
487
  else:
488
- mask_ce = F.one_hot(
489
- topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts
490
- )
491
  ce = mask_ce.float().mean(0)
492
  Pi = scores_for_aux.mean(0)
493
  fi = ce * self.n_routed_experts
@@ -499,10 +340,9 @@ class MoEGate(nn.Module):
499
 
500
  class AddAuxiliaryLoss(torch.autograd.Function):
501
  """
502
- The trick function of adding auxiliary (aux) loss,
503
  which includes the gradient of the aux loss during backpropagation.
504
  """
505
-
506
  @staticmethod
507
  def forward(ctx, x, loss):
508
  assert loss.numel() == 1
@@ -516,53 +356,22 @@ class AddAuxiliaryLoss(torch.autograd.Function):
516
  if ctx.required_aux_loss:
517
  grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
518
  return grad_output, grad_loss
519
-
520
-
521
- class DeepseekV2MoE(nn.Module):
522
  """
523
  A mixed expert module containing shared experts.
524
  """
525
-
526
  def __init__(self, config):
527
  super().__init__()
528
  self.config = config
529
  self.num_experts_per_tok = config.num_experts_per_tok
530
-
531
- if hasattr(config, "ep_size") and config.ep_size > 1:
532
- assert config.ep_size == dist.get_world_size()
533
- self.ep_size = config.ep_size
534
- self.experts_per_rank = config.n_routed_experts // config.ep_size
535
- self.ep_rank = dist.get_rank()
536
- self.experts = nn.ModuleList(
537
- [
538
- (
539
- DeepseekV2MLP(
540
- config, intermediate_size=config.moe_intermediate_size
541
- )
542
- if i >= self.ep_rank * self.experts_per_rank
543
- and i < (self.ep_rank + 1) * self.experts_per_rank
544
- else None
545
- )
546
- for i in range(config.n_routed_experts)
547
- ]
548
- )
549
- else:
550
- self.ep_size = 1
551
- self.experts_per_rank = config.n_routed_experts
552
- self.ep_rank = 0
553
- self.experts = nn.ModuleList(
554
- [
555
- DeepseekV2MLP(config, intermediate_size=config.moe_intermediate_size)
556
- for i in range(config.n_routed_experts)
557
- ]
558
- )
559
  self.gate = MoEGate(config)
560
  if config.n_shared_experts is not None:
561
  intermediate_size = config.moe_intermediate_size * config.n_shared_experts
562
- self.shared_experts = DeepseekV2MLP(
563
- config=config, intermediate_size=intermediate_size
564
- )
565
-
566
  def forward(self, hidden_states):
567
  identity = hidden_states
568
  orig_shape = hidden_states.shape
@@ -570,96 +379,36 @@ class DeepseekV2MoE(nn.Module):
570
  hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
571
  flat_topk_idx = topk_idx.view(-1)
572
  if self.training:
573
- hidden_states = hidden_states.repeat_interleave(
574
- self.num_experts_per_tok, dim=0
575
- )
576
  y = torch.empty_like(hidden_states)
577
  for i, expert in enumerate(self.experts):
578
  y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
579
  y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
580
- y = y.view(*orig_shape)
581
  y = AddAuxiliaryLoss.apply(y, aux_loss)
582
  else:
583
- y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
584
  if self.config.n_shared_experts is not None:
585
  y = y + self.shared_experts(identity)
586
  return y
587
-
588
  @torch.no_grad()
589
- def moe_infer(self, x, topk_ids, topk_weight):
590
- cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
591
- cnts.scatter_(1, topk_ids, 1)
592
- tokens_per_expert = cnts.sum(dim=0)
593
- idxs = topk_ids.view(-1).argsort()
594
- sorted_tokens = x[idxs // topk_ids.shape[1]]
595
- sorted_tokens_shape = sorted_tokens.shape
596
- if self.ep_size > 1:
597
- tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
598
- tokens_per_expert_group = tokens_per_expert.new_empty(
599
- tokens_per_expert.shape[0]
600
- )
601
- dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
602
- output_splits = (
603
- tokens_per_expert_group.view(self.ep_size, -1)
604
- .sum(1)
605
- .cpu()
606
- .numpy()
607
- .tolist()
608
- )
609
- gathered_tokens = sorted_tokens.new_empty(
610
- tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]
611
- )
612
- input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
613
- dist.all_to_all(
614
- list(gathered_tokens.split(output_splits)),
615
- list(sorted_tokens.split(input_split_sizes)),
616
- )
617
- tokens_per_expert_post_gather = tokens_per_expert_group.view(
618
- self.ep_size, self.experts_per_rank
619
- ).sum(dim=0)
620
- gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
621
- s = 0
622
- for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
623
- gatherd_idxs[s : s + k] = i % self.experts_per_rank
624
- s += k
625
- gatherd_idxs = gatherd_idxs.argsort()
626
- sorted_tokens = gathered_tokens[gatherd_idxs]
627
- tokens_per_expert = tokens_per_expert_post_gather
628
- tokens_per_expert = tokens_per_expert.cpu().numpy()
629
-
630
- outputs = []
631
- start_idx = 0
632
- for i, num_tokens in enumerate(tokens_per_expert):
633
- end_idx = start_idx + num_tokens
634
- if num_tokens == 0:
635
  continue
636
- expert = self.experts[i + self.ep_rank * self.experts_per_rank]
637
- tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
638
- expert_out = expert(tokens_for_this_expert)
639
- outputs.append(expert_out)
640
- start_idx = end_idx
641
-
642
- outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
643
- if self.ep_size > 1:
644
- new_x = torch.empty_like(outs)
645
- new_x[gatherd_idxs] = outs
646
- gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
647
- dist.all_to_all(
648
- list(gathered_tokens.split(input_split_sizes)),
649
- list(new_x.split(output_splits)),
650
- )
651
- outs = gathered_tokens
652
-
653
- new_x = torch.empty_like(outs)
654
- new_x[idxs] = outs
655
- final_out = (
656
- new_x.view(*topk_ids.shape, -1)
657
- .type(topk_weight.dtype)
658
- .mul_(topk_weight.unsqueeze(dim=-1))
659
- .sum(dim=1)
660
- .type(new_x.dtype)
661
- )
662
- return final_out
663
 
664
 
665
  # Copied from transformers.models.llama.modeling_llama.repeat_kv
@@ -671,17 +420,15 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
671
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
672
  if n_rep == 1:
673
  return hidden_states
674
- hidden_states = hidden_states[:, :, None, :, :].expand(
675
- batch, num_key_value_heads, n_rep, slen, head_dim
676
- )
677
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
678
 
679
 
680
- # Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2
681
- class DeepseekV2Attention(nn.Module):
682
  """Multi-headed attention from 'Attention Is All You Need' paper"""
683
 
684
- def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None):
685
  super().__init__()
686
  self.config = config
687
  self.layer_idx = layer_idx
@@ -695,63 +442,29 @@ class DeepseekV2Attention(nn.Module):
695
  self.attention_dropout = config.attention_dropout
696
  self.hidden_size = config.hidden_size
697
  self.num_heads = config.num_attention_heads
698
-
 
 
699
  self.max_position_embeddings = config.max_position_embeddings
700
  self.rope_theta = config.rope_theta
701
- self.q_lora_rank = config.q_lora_rank
702
- self.qk_rope_head_dim = config.qk_rope_head_dim
703
- self.kv_lora_rank = config.kv_lora_rank
704
- self.v_head_dim = config.v_head_dim
705
- self.qk_nope_head_dim = config.qk_nope_head_dim
706
- self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
707
-
708
  self.is_causal = True
709
 
710
- if self.q_lora_rank is None:
711
- self.q_proj = nn.Linear(
712
- self.hidden_size, self.num_heads * self.q_head_dim, bias=False
713
- )
714
- else:
715
- self.q_a_proj = nn.Linear(
716
- self.hidden_size, config.q_lora_rank, bias=config.attention_bias
717
- )
718
- self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank)
719
- self.q_b_proj = nn.Linear(
720
- config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
721
  )
722
 
723
- self.kv_a_proj_with_mqa = nn.Linear(
724
- self.hidden_size,
725
- config.kv_lora_rank + config.qk_rope_head_dim,
726
- bias=config.attention_bias,
727
- )
728
- self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank)
729
- self.kv_b_proj = nn.Linear(
730
- config.kv_lora_rank,
731
- self.num_heads
732
- * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
733
- bias=False,
734
- )
735
-
736
- self.o_proj = nn.Linear(
737
- self.num_heads * self.v_head_dim,
738
- self.hidden_size,
739
- bias=config.attention_bias,
740
- )
741
  self._init_rope()
742
 
743
- self.softmax_scale = self.q_head_dim ** (-0.5)
744
- if self.config.rope_scaling is not None:
745
- mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
746
- scaling_factor = self.config.rope_scaling["factor"]
747
- if mscale_all_dim:
748
- mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
749
- self.softmax_scale = self.softmax_scale * mscale * mscale
750
-
751
  def _init_rope(self):
752
  if self.config.rope_scaling is None:
753
- self.rotary_emb = DeepseekV2RotaryEmbedding(
754
- self.qk_rope_head_dim,
755
  max_position_embeddings=self.max_position_embeddings,
756
  base=self.rope_theta,
757
  )
@@ -759,47 +472,24 @@ class DeepseekV2Attention(nn.Module):
759
  scaling_type = self.config.rope_scaling["type"]
760
  scaling_factor = self.config.rope_scaling["factor"]
761
  if scaling_type == "linear":
762
- self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding(
763
- self.qk_rope_head_dim,
764
  max_position_embeddings=self.max_position_embeddings,
765
  scaling_factor=scaling_factor,
766
  base=self.rope_theta,
767
  )
768
  elif scaling_type == "dynamic":
769
- self.rotary_emb = DeepseekV2DynamicNTKScalingRotaryEmbedding(
770
- self.qk_rope_head_dim,
771
  max_position_embeddings=self.max_position_embeddings,
772
  scaling_factor=scaling_factor,
773
  base=self.rope_theta,
774
  )
775
- elif scaling_type == "yarn":
776
- kwargs = {
777
- key: self.config.rope_scaling[key]
778
- for key in [
779
- "original_max_position_embeddings",
780
- "beta_fast",
781
- "beta_slow",
782
- "mscale",
783
- "mscale_all_dim",
784
- ]
785
- if key in self.config.rope_scaling
786
- }
787
- self.rotary_emb = DeepseekV2YarnRotaryEmbedding(
788
- self.qk_rope_head_dim,
789
- max_position_embeddings=self.max_position_embeddings,
790
- scaling_factor=scaling_factor,
791
- base=self.rope_theta,
792
- **kwargs,
793
- )
794
  else:
795
  raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
796
 
797
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
798
- return (
799
- tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim)
800
- .transpose(1, 2)
801
- .contiguous()
802
- )
803
 
804
  def forward(
805
  self,
@@ -815,32 +505,36 @@ class DeepseekV2Attention(nn.Module):
815
  warnings.warn(
816
  "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
817
  )
 
818
  bsz, q_len, _ = hidden_states.size()
819
 
820
- if self.q_lora_rank is None:
821
- q = self.q_proj(hidden_states)
822
  else:
823
- q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
824
- q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
825
- q_nope, q_pe = torch.split(
826
- q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
827
- )
828
 
829
- compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
830
- compressed_kv, k_pe = torch.split(
831
- compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
832
- )
833
- k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
834
- kv = (
835
- self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
836
- .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
837
- .transpose(1, 2)
838
- )
839
 
840
- k_nope, value_states = torch.split(
841
- kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
842
- )
843
- kv_seq_len = value_states.shape[-2]
844
  if past_key_value is not None:
845
  if self.layer_idx is None:
846
  raise ValueError(
@@ -850,32 +544,23 @@ class DeepseekV2Attention(nn.Module):
850
  )
851
  kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
852
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
853
 
854
- q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
855
-
856
- query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
857
- query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
858
- query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
859
-
860
- key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
861
- key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
862
- key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
863
  if past_key_value is not None:
864
  cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
865
- key_states, value_states = past_key_value.update(
866
- key_states, value_states, self.layer_idx, cache_kwargs
867
- )
868
 
869
- attn_weights = (
870
- torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
871
- )
 
872
 
873
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
874
  raise ValueError(
875
  f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
876
  f" {attn_weights.size()}"
877
  )
878
- assert attention_mask is not None
879
  if attention_mask is not None:
880
  if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
881
  raise ValueError(
@@ -884,25 +569,26 @@ class DeepseekV2Attention(nn.Module):
884
  attn_weights = attn_weights + attention_mask
885
 
886
  # upcast attention to fp32
887
- attn_weights = nn.functional.softmax(
888
- attn_weights, dim=-1, dtype=torch.float32
889
- ).to(query_states.dtype)
890
- attn_weights = nn.functional.dropout(
891
- attn_weights, p=self.attention_dropout, training=self.training
892
- )
893
  attn_output = torch.matmul(attn_weights, value_states)
894
 
895
- if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
896
  raise ValueError(
897
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"
898
  f" {attn_output.size()}"
899
  )
900
 
901
  attn_output = attn_output.transpose(1, 2).contiguous()
902
 
903
- attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
904
 
905
- attn_output = self.o_proj(attn_output)
906
 
907
  if not output_attentions:
908
  attn_weights = None
@@ -910,10 +596,10 @@ class DeepseekV2Attention(nn.Module):
910
  return attn_output, attn_weights, past_key_value
911
 
912
 
913
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV2
914
- class DeepseekV2FlashAttention2(DeepseekV2Attention):
915
  """
916
- DeepseekV2 flash attention module. This module inherits from `DeepseekV2Attention` as the weights of the module stays
917
  untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
918
  flash attention and deal with padding tokens in case the input contains any of them.
919
  """
@@ -936,7 +622,7 @@ class DeepseekV2FlashAttention2(DeepseekV2Attention):
936
  use_cache: bool = False,
937
  **kwargs,
938
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
939
- # DeepseekV2FlashAttention2 attention does not support output_attentions
940
  if "padding_mask" in kwargs:
941
  warnings.warn(
942
  "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
@@ -949,57 +635,26 @@ class DeepseekV2FlashAttention2(DeepseekV2Attention):
949
 
950
  bsz, q_len, _ = hidden_states.size()
951
 
952
- if self.q_lora_rank is None:
953
- q = self.q_proj(hidden_states)
954
- else:
955
- q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
956
- q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
957
- q_nope, q_pe = torch.split(
958
- q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
959
- )
960
 
961
  # Flash attention requires the input to have the shape
962
  # batch_size x seq_length x head_dim x hidden_dim
963
  # therefore we just need to keep the original shape
964
- compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
965
- compressed_kv, k_pe = torch.split(
966
- compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
967
- )
968
- k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
969
- kv = (
970
- self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
971
- .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
972
- .transpose(1, 2)
973
- )
974
-
975
- k_nope, value_states = torch.split(
976
- kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
977
- )
978
- kv_seq_len = value_states.shape[-2]
979
 
980
- kv_seq_len = value_states.shape[-2]
981
  if past_key_value is not None:
982
  kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
983
-
984
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
985
- q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
986
-
987
- query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
988
- query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
989
- query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
990
-
991
- key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
992
- key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
993
- key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
994
-
995
- if self.q_head_dim != self.v_head_dim:
996
- value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])
997
 
998
  if past_key_value is not None:
999
  cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
1000
- key_states, value_states = past_key_value.update(
1001
- key_states, value_states, self.layer_idx, cache_kwargs
1002
- )
1003
 
1004
  # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
1005
  # to be able to avoid many of these transpose/reshape/view.
@@ -1013,7 +668,7 @@ class DeepseekV2FlashAttention2(DeepseekV2Attention):
1013
  # therefore the input hidden states gets silently casted in float32. Hence, we need
1014
  # cast them back in the correct dtype just to be sure everything works as expected.
1015
  # This might slowdown training & inference so it is recommended to not cast the LayerNorms
1016
- # in fp32. (DeepseekV2RMSNorm handles it correctly)
1017
 
1018
  input_dtype = query_states.dtype
1019
  if input_dtype == torch.float32:
@@ -1023,7 +678,7 @@ class DeepseekV2FlashAttention2(DeepseekV2Attention):
1023
  elif torch.is_autocast_enabled():
1024
  target_dtype = torch.get_autocast_gpu_dtype()
1025
  else:
1026
- target_dtype = self.q_proj.weight.dtype if self.q_lora_rank is None else self.q_a_proj.weight.dtype
1027
 
1028
  logger.warning_once(
1029
  f"The input hidden states seems to be silently casted in float32, this might be related to"
@@ -1036,20 +691,10 @@ class DeepseekV2FlashAttention2(DeepseekV2Attention):
1036
  value_states = value_states.to(target_dtype)
1037
 
1038
  attn_output = self._flash_attention_forward(
1039
- query_states,
1040
- key_states,
1041
- value_states,
1042
- attention_mask,
1043
- q_len,
1044
- dropout=dropout_rate,
1045
- softmax_scale=self.softmax_scale,
1046
  )
1047
- if self.q_head_dim != self.v_head_dim:
1048
- attn_output = attn_output[:, :, :, : self.v_head_dim]
1049
 
1050
- attn_output = attn_output.reshape(
1051
- bsz, q_len, self.num_heads * self.v_head_dim
1052
- ).contiguous()
1053
  attn_output = self.o_proj(attn_output)
1054
 
1055
  if not output_attentions:
@@ -1058,14 +703,7 @@ class DeepseekV2FlashAttention2(DeepseekV2Attention):
1058
  return attn_output, attn_weights, past_key_value
1059
 
1060
  def _flash_attention_forward(
1061
- self,
1062
- query_states,
1063
- key_states,
1064
- value_states,
1065
- attention_mask,
1066
- query_length,
1067
- dropout=0.0,
1068
- softmax_scale=None,
1069
  ):
1070
  """
1071
  Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -1089,20 +727,13 @@ class DeepseekV2FlashAttention2(DeepseekV2Attention):
1089
  if not self._flash_attn_uses_top_left_mask:
1090
  causal = self.is_causal
1091
  else:
1092
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV2FlashAttention2 __init__.
1093
  causal = self.is_causal and query_length != 1
1094
 
1095
  # Contains at least one padding token in the sequence
1096
  if attention_mask is not None:
1097
  batch_size = query_states.shape[0]
1098
- (
1099
- query_states,
1100
- key_states,
1101
- value_states,
1102
- indices_q,
1103
- cu_seq_lens,
1104
- max_seq_lens,
1105
- ) = self._upad_input(
1106
  query_states, key_states, value_states, attention_mask, query_length
1107
  )
1108
 
@@ -1122,39 +753,27 @@ class DeepseekV2FlashAttention2(DeepseekV2Attention):
1122
  causal=causal,
1123
  )
1124
 
1125
- attn_output = pad_input(
1126
- attn_output_unpad, indices_q, batch_size, query_length
1127
- )
1128
  else:
1129
  attn_output = flash_attn_func(
1130
- query_states,
1131
- key_states,
1132
- value_states,
1133
- dropout,
1134
- softmax_scale=softmax_scale,
1135
- causal=causal,
1136
  )
1137
 
1138
  return attn_output
1139
 
1140
- def _upad_input(
1141
- self, query_layer, key_layer, value_layer, attention_mask, query_length
1142
- ):
1143
  indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
1144
  batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
1145
 
1146
  key_layer = index_first_axis(
1147
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
1148
- indices_k,
1149
  )
1150
  value_layer = index_first_axis(
1151
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
1152
- indices_k,
1153
  )
1154
  if query_length == kv_seq_len:
1155
  query_layer = index_first_axis(
1156
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
1157
- indices_k,
1158
  )
1159
  cu_seqlens_q = cu_seqlens_k
1160
  max_seqlen_in_batch_q = max_seqlen_in_batch_k
@@ -1169,9 +788,7 @@ class DeepseekV2FlashAttention2(DeepseekV2Attention):
1169
  else:
1170
  # The -q_len: slice assumes left padding.
1171
  attention_mask = attention_mask[:, -query_length:]
1172
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
1173
- query_layer, attention_mask
1174
- )
1175
 
1176
  return (
1177
  query_layer,
@@ -1183,36 +800,113 @@ class DeepseekV2FlashAttention2(DeepseekV2Attention):
1183
  )
1184
 
1185
 
1186
- ATTENTION_CLASSES = {
1187
- "eager": DeepseekV2Attention,
1188
- "flash_attention_2": DeepseekV2FlashAttention2,
1189
  }
1190
 
1191
 
1192
- class DeepseekV2DecoderLayer(nn.Module):
1193
- def __init__(self, config: DeepseekV2Config, layer_idx: int):
1194
  super().__init__()
1195
  self.hidden_size = config.hidden_size
1196
 
1197
- self.self_attn = ATTENTION_CLASSES[config._attn_implementation](
1198
- config=config, layer_idx=layer_idx
1199
- )
1200
 
1201
- self.mlp = (
1202
- DeepseekV2MoE(config)
1203
- if (
1204
- config.n_routed_experts is not None
1205
- and layer_idx >= config.first_k_dense_replace
1206
- and layer_idx % config.moe_layer_freq == 0
1207
- )
1208
- else DeepseekV2MLP(config)
1209
- )
1210
- self.input_layernorm = DeepseekV2RMSNorm(
1211
- config.hidden_size, eps=config.rms_norm_eps
1212
- )
1213
- self.post_attention_layernorm = DeepseekV2RMSNorm(
1214
- config.hidden_size, eps=config.rms_norm_eps
1215
- )
1216
 
1217
  def forward(
1218
  self,
@@ -1223,9 +917,7 @@ class DeepseekV2DecoderLayer(nn.Module):
1223
  output_attentions: Optional[bool] = False,
1224
  use_cache: Optional[bool] = False,
1225
  **kwargs,
1226
- ) -> Tuple[
1227
- torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
1228
- ]:
1229
  """
1230
  Args:
1231
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -1277,7 +969,7 @@ class DeepseekV2DecoderLayer(nn.Module):
1277
  return outputs
1278
 
1279
 
1280
- DeepseekV2_START_DOCSTRING = r"""
1281
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1282
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1283
  etc.)
@@ -1287,7 +979,7 @@ DeepseekV2_START_DOCSTRING = r"""
1287
  and behavior.
1288
 
1289
  Parameters:
1290
- config ([`DeepseekV2Config`]):
1291
  Model configuration class with all the parameters of the model. Initializing with a config file does not
1292
  load the weights associated with the model, only the configuration. Check out the
1293
  [`~PreTrainedModel.from_pretrained`] method to load the model weights.
@@ -1295,16 +987,17 @@ DeepseekV2_START_DOCSTRING = r"""
1295
 
1296
 
1297
  @add_start_docstrings(
1298
- "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.",
1299
- DeepseekV2_START_DOCSTRING,
1300
  )
1301
- class DeepseekV2PreTrainedModel(PreTrainedModel):
1302
- config_class = DeepseekV2Config
1303
  base_model_prefix = "model"
1304
  supports_gradient_checkpointing = True
1305
- _no_split_modules = ["DeepseekV2DecoderLayer"]
1306
  _skip_keys_device_placement = "past_key_values"
1307
  _supports_flash_attn_2 = True
 
1308
  _supports_cache_class = True
1309
 
1310
  def _init_weights(self, module):
@@ -1319,7 +1012,7 @@ class DeepseekV2PreTrainedModel(PreTrainedModel):
1319
  module.weight.data[module.padding_idx].zero_()
1320
 
1321
 
1322
- DeepseekV2_INPUTS_DOCSTRING = r"""
1323
  Args:
1324
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1325
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
@@ -1390,33 +1083,29 @@ DeepseekV2_INPUTS_DOCSTRING = r"""
1390
 
1391
 
1392
  @add_start_docstrings(
1393
- "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.",
1394
- DeepseekV2_START_DOCSTRING,
1395
  )
1396
- class DeepseekV2Model(DeepseekV2PreTrainedModel):
1397
  """
1398
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV2DecoderLayer`]
1399
 
1400
  Args:
1401
- config: DeepseekV2Config
1402
  """
1403
 
1404
- def __init__(self, config: DeepseekV2Config):
1405
  super().__init__(config)
1406
  self.padding_idx = config.pad_token_id
1407
  self.vocab_size = config.vocab_size
1408
 
1409
- self.embed_tokens = nn.Embedding(
1410
- config.vocab_size, config.hidden_size, self.padding_idx
1411
- )
1412
  self.layers = nn.ModuleList(
1413
- [
1414
- DeepseekV2DecoderLayer(config, layer_idx)
1415
- for layer_idx in range(config.num_hidden_layers)
1416
- ]
1417
  )
 
1418
  self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1419
- self.norm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1420
 
1421
  self.gradient_checkpointing = False
1422
  # Initialize weights and apply final processing
@@ -1428,7 +1117,7 @@ class DeepseekV2Model(DeepseekV2PreTrainedModel):
1428
  def set_input_embeddings(self, value):
1429
  self.embed_tokens = value
1430
 
1431
- @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
1432
  def forward(
1433
  self,
1434
  input_ids: torch.LongTensor = None,
@@ -1441,27 +1130,17 @@ class DeepseekV2Model(DeepseekV2PreTrainedModel):
1441
  output_hidden_states: Optional[bool] = None,
1442
  return_dict: Optional[bool] = None,
1443
  ) -> Union[Tuple, BaseModelOutputWithPast]:
1444
- output_attentions = (
1445
- output_attentions
1446
- if output_attentions is not None
1447
- else self.config.output_attentions
1448
- )
1449
  output_hidden_states = (
1450
- output_hidden_states
1451
- if output_hidden_states is not None
1452
- else self.config.output_hidden_states
1453
  )
1454
  use_cache = use_cache if use_cache is not None else self.config.use_cache
1455
 
1456
- return_dict = (
1457
- return_dict if return_dict is not None else self.config.use_return_dict
1458
- )
1459
 
1460
  # retrieve input_ids and inputs_embeds
1461
  if input_ids is not None and inputs_embeds is not None:
1462
- raise ValueError(
1463
- "You cannot specify both input_ids and inputs_embeds at the same time"
1464
- )
1465
  elif input_ids is not None:
1466
  batch_size, seq_length = input_ids.shape[:2]
1467
  elif inputs_embeds is not None:
@@ -1486,10 +1165,7 @@ class DeepseekV2Model(DeepseekV2PreTrainedModel):
1486
  if position_ids is None:
1487
  device = input_ids.device if input_ids is not None else inputs_embeds.device
1488
  position_ids = torch.arange(
1489
- past_key_values_length,
1490
- seq_length + past_key_values_length,
1491
- dtype=torch.long,
1492
- device=device,
1493
  )
1494
  position_ids = position_ids.unsqueeze(0)
1495
 
@@ -1498,19 +1174,21 @@ class DeepseekV2Model(DeepseekV2PreTrainedModel):
1498
 
1499
  if self._use_flash_attention_2:
1500
  # 2d mask is passed through the layers
1501
- attention_mask = (
1502
- attention_mask
1503
- if (attention_mask is not None and 0 in attention_mask)
1504
- else None
1505
- )
1506
- else:
1507
- # 4d mask is passed through the layers
1508
- attention_mask = _prepare_4d_causal_attention_mask(
1509
  attention_mask,
1510
  (batch_size, seq_length),
1511
  inputs_embeds,
1512
  past_key_values_length,
1513
  )
1514
 
1515
  # embed positions
1516
  hidden_states = inputs_embeds
@@ -1560,17 +1238,9 @@ class DeepseekV2Model(DeepseekV2PreTrainedModel):
1560
 
1561
  next_cache = None
1562
  if use_cache:
1563
- next_cache = (
1564
- next_decoder_cache.to_legacy_cache()
1565
- if use_legacy_cache
1566
- else next_decoder_cache
1567
- )
1568
  if not return_dict:
1569
- return tuple(
1570
- v
1571
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1572
- if v is not None
1573
- )
1574
  return BaseModelOutputWithPast(
1575
  last_hidden_state=hidden_states,
1576
  past_key_values=next_cache,
@@ -1579,12 +1249,12 @@ class DeepseekV2Model(DeepseekV2PreTrainedModel):
1579
  )
1580
 
1581
 
1582
- class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
1583
  _tied_weights_keys = ["lm_head.weight"]
1584
 
1585
  def __init__(self, config):
1586
  super().__init__(config)
1587
- self.model = DeepseekV2Model(config)
1588
  self.vocab_size = config.vocab_size
1589
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1590
 
@@ -1609,10 +1279,8 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
1609
  def get_decoder(self):
1610
  return self.model
1611
 
1612
- @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
1613
- @replace_return_docstrings(
1614
- output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
1615
- )
1616
  def forward(
1617
  self,
1618
  input_ids: torch.LongTensor = None,
@@ -1638,9 +1306,9 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
1638
  Example:
1639
 
1640
  ```python
1641
- >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM
1642
 
1643
- >>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1644
  >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1645
 
1646
  >>> prompt = "Hey, are you conscious? Can you talk to me?"
@@ -1651,19 +1319,11 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
1651
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1652
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1653
  ```"""
1654
- output_attentions = (
1655
- output_attentions
1656
- if output_attentions is not None
1657
- else self.config.output_attentions
1658
- )
1659
  output_hidden_states = (
1660
- output_hidden_states
1661
- if output_hidden_states is not None
1662
- else self.config.output_hidden_states
1663
- )
1664
- return_dict = (
1665
- return_dict if return_dict is not None else self.config.use_return_dict
1666
  )
 
1667
 
1668
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1669
  outputs = self.model(
@@ -1679,7 +1339,12 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
1679
  )
1680
 
1681
  hidden_states = outputs[0]
1682
- logits = self.lm_head(hidden_states)
1683
  logits = logits.float()
1684
 
1685
  loss = None
@@ -1708,12 +1373,7 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
1708
  )
1709
 
1710
  def prepare_inputs_for_generation(
1711
- self,
1712
- input_ids,
1713
- past_key_values=None,
1714
- attention_mask=None,
1715
- inputs_embeds=None,
1716
- **kwargs,
1717
  ):
1718
  if past_key_values is not None:
1719
  if isinstance(past_key_values, Cache):
@@ -1728,10 +1388,7 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
1728
  # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1729
  # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
1730
  # input)
1731
- if (
1732
- attention_mask is not None
1733
- and attention_mask.shape[1] > input_ids.shape[1]
1734
- ):
1735
  input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1736
  # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1737
  # input_ids based on the past_length.
@@ -1776,19 +1433,16 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
1776
  reordered_past = ()
1777
  for layer_past in past_key_values:
1778
  reordered_past += (
1779
- tuple(
1780
- past_state.index_select(0, beam_idx.to(past_state.device))
1781
- for past_state in layer_past
1782
- ),
1783
  )
1784
  return reordered_past
1785
 
1786
 
1787
  @add_start_docstrings(
1788
  """
1789
- The DeepseekV2 Model transformer with a sequence classification head on top (linear layer).
1790
 
1791
- [`DeepseekV2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1792
  (e.g. GPT-2) do.
1793
 
1794
  Since it does classification on the last token, it requires to know the position of the last token. If a
@@ -1797,13 +1451,13 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
1797
  padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1798
  each row of the batch).
1799
  """,
1800
- DeepseekV2_START_DOCSTRING,
1801
  )
1802
- class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel):
1803
  def __init__(self, config):
1804
  super().__init__(config)
1805
  self.num_labels = config.num_labels
1806
- self.model = DeepseekV2Model(config)
1807
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1808
 
1809
  # Initialize weights and apply final processing
@@ -1815,7 +1469,7 @@ class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel):
1815
  def set_input_embeddings(self, value):
1816
  self.model.embed_tokens = value
1817
 
1818
- @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
1819
  def forward(
1820
  self,
1821
  input_ids: torch.LongTensor = None,
@@ -1835,9 +1489,7 @@ class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel):
1835
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1836
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1837
  """
1838
- return_dict = (
1839
- return_dict if return_dict is not None else self.config.use_return_dict
1840
- )
1841
 
1842
  transformer_outputs = self.model(
1843
  input_ids,
@@ -1859,22 +1511,18 @@ class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel):
1859
  batch_size = inputs_embeds.shape[0]
1860
 
1861
  if self.config.pad_token_id is None and batch_size != 1:
1862
- raise ValueError(
1863
- "Cannot handle batch sizes > 1 if no padding token is defined."
1864
- )
1865
  if self.config.pad_token_id is None:
1866
  sequence_lengths = -1
1867
  else:
1868
  if input_ids is not None:
1869
- sequence_lengths = (
1870
- torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1871
- ).to(logits.device)
1872
  else:
1873
  sequence_lengths = -1
1874
 
1875
- pooled_logits = logits[
1876
- torch.arange(batch_size, device=logits.device), sequence_lengths
1877
- ]
1878
 
1879
  loss = None
1880
  if labels is not None:
@@ -1882,9 +1530,7 @@ class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel):
1882
  if self.config.problem_type is None:
1883
  if self.num_labels == 1:
1884
  self.config.problem_type = "regression"
1885
- elif self.num_labels > 1 and (
1886
- labels.dtype == torch.long or labels.dtype == torch.int
1887
- ):
1888
  self.config.problem_type = "single_label_classification"
1889
  else:
1890
  self.config.problem_type = "multi_label_classification"
@@ -1897,9 +1543,7 @@ class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel):
1897
  loss = loss_fct(pooled_logits, labels)
1898
  elif self.config.problem_type == "single_label_classification":
1899
  loss_fct = CrossEntropyLoss()
1900
- loss = loss_fct(
1901
- pooled_logits.view(-1, self.num_labels), labels.view(-1)
1902
- )
1903
  elif self.config.problem_type == "multi_label_classification":
1904
  loss_fct = BCEWithLogitsLoss()
1905
  loss = loss_fct(pooled_logits, labels)
@@ -1913,4 +1557,4 @@ class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel):
1913
  past_key_values=transformer_outputs.past_key_values,
1914
  hidden_states=transformer_outputs.hidden_states,
1915
  attentions=transformer_outputs.attentions,
1916
- )
 
5
  # and OPT implementations in this library. It has been modified from its
6
  # original forms to accommodate minor architectural differences compared
7
  # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
  # Licensed under the Apache License, Version 2.0 (the "License");
10
  # you may not use this file except in compliance with the License.
11
  # You may obtain a copy of the License at
 
34
  AttentionMaskConverter,
35
  _prepare_4d_attention_mask,
36
  _prepare_4d_causal_attention_mask,
37
+ _prepare_4d_causal_attention_mask_for_sdpa,
38
  )
39
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
40
  from transformers.modeling_utils import PreTrainedModel
41
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
42
  from transformers.utils import (
43
  add_start_docstrings,
44
  add_start_docstrings_to_model_forward,
 
48
  replace_return_docstrings,
49
  )
50
  from transformers.utils.import_utils import is_torch_fx_available
51
+ from .configuration_deepseek import DeepseekConfig
52
+
 
53
 
54
  if is_flash_attn_2_available():
55
  from flash_attn import flash_attn_func, flash_attn_varlen_func
 
67
 
68
  logger = logging.get_logger(__name__)
69
 
70
+ _CONFIG_FOR_DOC = "DeepseekConfig"
71
 
72
 
73
  def _get_unpad_data(attention_mask):
74
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
75
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
76
  max_seqlen_in_batch = seqlens_in_batch.max().item()
77
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
 
 
78
  return (
79
  indices,
80
  cu_seqlens,
 
82
  )
83
 
84
 
85
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
86
+ warnings.warn(
87
+ "Calling `transformers.models.Deepseek.modeling_Deepseek._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask"
88
+ )
89
+ return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
90
+
91
+
92
+ def _make_causal_mask(
93
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
94
+ ):
95
+ warnings.warn(
96
+ "Calling `transformers.models.Deepseek.modeling_Deepseek._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.Deepseek.modeling_Deepseek.AttentionMaskConverter._make_causal_mask"
97
+ )
98
+ return AttentionMaskConverter._make_causal_mask(
99
+ input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length
100
+ )
101
+
102
+
103
+ class DeepseekRMSNorm(nn.Module):
104
  def __init__(self, hidden_size, eps=1e-6):
105
  """
106
+ DeepseekRMSNorm is equivalent to T5LayerNorm
107
  """
108
  super().__init__()
109
  self.weight = nn.Parameter(torch.ones(hidden_size))
 
117
  return self.weight * hidden_states.to(input_dtype)
118
 
119
 
120
+ ALL_LAYERNORM_LAYERS.append(DeepseekRMSNorm)
121
 
122
 
123
+ class DeepseekRotaryEmbedding(nn.Module):
124
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
125
  super().__init__()
126
 
127
  self.dim = dim
128
  self.max_position_embeddings = max_position_embeddings
129
  self.base = base
130
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
 
 
131
  self.register_buffer("inv_freq", inv_freq, persistent=False)
132
 
133
  # Build here to make `torch.jit.trace` work.
134
  self._set_cos_sin_cache(
135
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
 
 
136
  )
137
  self.max_seq_len_cached = None
138
 
139
+
140
  def _set_cos_sin_cache(self, seq_len, device, dtype):
141
  self.max_seq_len_cached = seq_len
142
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
 
 
143
 
144
  freqs = torch.outer(t, self.inv_freq.to(t.device))
145
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
 
158
  )
159
 
160
 
161
+ # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Deepseek
162
+ class DeepseekLinearScalingRotaryEmbedding(DeepseekRotaryEmbedding):
163
+ """DeepseekRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
164
 
165
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
 
 
 
 
 
 
 
166
  self.scaling_factor = scaling_factor
167
  super().__init__(dim, max_position_embeddings, base, device)
168
 
169
  def _set_cos_sin_cache(self, seq_len, device, dtype):
170
  self.max_seq_len_cached = seq_len
171
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
 
 
172
  t = t / self.scaling_factor
173
 
174
  freqs = torch.outer(t, self.inv_freq)
 
178
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
179
 
180
 
181
+ # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Deepseek
182
+ class DeepseekDynamicNTKScalingRotaryEmbedding(DeepseekRotaryEmbedding):
183
+ """DeepseekRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
184
 
185
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
 
 
 
 
 
 
 
186
  self.scaling_factor = scaling_factor
187
  super().__init__(dim, max_position_embeddings, base, device)
188
 
 
191
 
192
  if seq_len > self.max_position_embeddings:
193
  base = self.base * (
194
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
 
195
  ) ** (self.dim / (self.dim - 2))
196
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
 
 
197
  self.register_buffer("inv_freq", inv_freq, persistent=False)
198
 
199
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
 
 
200
 
201
  freqs = torch.outer(t, self.inv_freq)
202
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
 
205
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
206
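The two scaling subclasses handle long contexts differently: linear scaling divides the position indices by scaling_factor, while dynamic NTK scaling enlarges the rotary base once seq_len exceeds max_position_embeddings. A rough standalone comparison with made-up values (dim=8, base=10000, factor=2):

import torch

dim, base, factor = 8, 10000.0, 2.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

# Linear scaling: same frequencies, positions shrunk by the factor.
t_linear = torch.arange(16).float() / factor

# Dynamic NTK: same positions, larger base once seq_len > max_position_embeddings,
# which stretches the low-frequency components the most.
max_pos, seq_len = 8, 16
ntk_base = base * ((factor * seq_len / max_pos) - (factor - 1)) ** (dim / (dim - 2))
ntk_inv_freq = 1.0 / (ntk_base ** (torch.arange(0, dim, 2).float() / dim))
print(t_linear[-1].item(), inv_freq[-1].item(), ntk_inv_freq[-1].item())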
 
207
208
  # Copied from transformers.models.llama.modeling_llama.rotate_half
209
  def rotate_half(x):
210
  """Rotates half the hidden dims of the input."""
 
237
  """
238
  cos = cos[position_ids].unsqueeze(unsqueeze_dim)
239
  sin = sin[position_ids].unsqueeze(unsqueeze_dim)
 
 
 
 
 
 
 
240
  q_embed = (q * cos) + (rotate_half(q) * sin)
241
  k_embed = (k * cos) + (rotate_half(k) * sin)
242
  return q_embed, k_embed
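A quick numeric check of the rotation helpers above, using a toy head_dim of 4 (values are illustrative only):

import torch

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
x1, x2 = x[..., :2], x[..., 2:]
rotated = torch.cat((-x2, x1), dim=-1)
print(rotated.tolist())   # [[-3.0, -4.0, 1.0, 2.0]] -> what rotate_half produces
# apply_rotary_pos_emb then blends q and rotate_half(q) with per-position cos/sin.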
243
 
244
 
245
+ class DeepseekMLP(nn.Module):
246
+ def __init__(self, config, hidden_size = None, intermediate_size = None):
247
  super().__init__()
248
  self.config = config
249
  self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
250
+ self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
 
 
251
 
252
  self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
253
  self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
 
255
  self.act_fn = ACT2FN[config.hidden_act]
256
 
257
  def forward(self, x):
258
+ if self.config.pretraining_tp > 1:
259
+ slice = self.intermediate_size // self.config.pretraining_tp
260
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
261
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
262
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
263
+
264
+ gate_proj = torch.cat(
265
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
266
+ )
267
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
268
+
269
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
270
+ down_proj = [
271
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
272
+ ]
273
+ down_proj = sum(down_proj)
274
+ else:
275
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
276
+
277
  return down_proj
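For the usual pretraining_tp == 1 path, the block above is a SwiGLU-style MLP; a functional sketch with made-up sizes, assuming the configured activation is SiLU:

import torch
import torch.nn.functional as F

hidden_size, intermediate_size = 16, 32
x = torch.randn(2, 5, hidden_size)
gate = torch.randn(intermediate_size, hidden_size)   # stands in for gate_proj.weight
up = torch.randn(intermediate_size, hidden_size)     # stands in for up_proj.weight
down = torch.randn(hidden_size, intermediate_size)   # stands in for down_proj.weight
y = F.linear(F.silu(F.linear(x, gate)) * F.linear(x, up), down)
print(y.shape)                                       # torch.Size([2, 5, 16])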
278
 
279
 
 
283
  self.config = config
284
  self.top_k = config.num_experts_per_tok
285
  self.n_routed_experts = config.n_routed_experts
286
+
287
  self.scoring_func = config.scoring_func
288
  self.alpha = config.aux_loss_alpha
289
  self.seq_aux = config.seq_aux
 
 
 
290
 
291
  # topk selection algorithm
292
  self.norm_topk_prob = config.norm_topk_prob
293
  self.gating_dim = config.hidden_size
294
+ self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
 
 
295
  self.reset_parameters()
296
 
297
  def reset_parameters(self) -> None:
298
+ import torch.nn.init as init
 
299
  init.kaiming_uniform_(self.weight, a=math.sqrt(5))
300
+
301
  def forward(self, hidden_states):
302
+ bsz, seq_len, h = hidden_states.shape
303
  ### compute gating score
304
  hidden_states = hidden_states.view(-1, h)
305
+ logits = F.linear(hidden_states, self.weight, None)
306
+ if self.scoring_func == 'softmax':
307
+ scores = logits.softmax(dim=-1)
 
 
308
  else:
309
+ raise NotImplementedError(f'unsupported scoring function for MoE gating: {self.scoring_func}')
310
+
 
 
311
  ### select top-k experts
312
+ topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
313
+
314
  ### norm gate to sum 1
315
  if self.top_k > 1 and self.norm_topk_prob:
316
  denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
317
  topk_weight = topk_weight / denominator
318
+
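Putting the scoring, top-k selection, and renormalisation steps above together, a standalone toy sketch (4 experts, top-2, random logits; the real module additionally computes the auxiliary loss during training):

import torch
import torch.nn.functional as F

n_routed_experts, top_k = 4, 2
hidden = torch.randn(3, 8)                       # 3 flattened tokens, toy gating_dim of 8
gate_weight = torch.randn(n_routed_experts, 8)   # stands in for the learned gate matrix
scores = F.linear(hidden, gate_weight).softmax(dim=-1)
topk_weight, topk_idx = torch.topk(scores, k=top_k, dim=-1, sorted=False)
topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-20)
print(topk_idx.shape, topk_weight.shape)         # torch.Size([3, 2]) for both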
 
319
  ### expert-level computation auxiliary loss
320
  if self.training and self.alpha > 0.0:
321
  scores_for_aux = scores
 
324
  topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
325
  if self.seq_aux:
326
  scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
327
+ ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
328
+ ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(seq_len * aux_topk / self.n_routed_experts)
329
+ aux_loss = (ce * scores_for_seq_aux.mean(dim = 1)).sum(dim = 1).mean() * self.alpha
 
 
 
 
 
 
 
 
330
  else:
331
+ mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
 
 
332
  ce = mask_ce.float().mean(0)
333
  Pi = scores_for_aux.mean(0)
334
  fi = ce * self.n_routed_experts
 
340
 
341
  class AddAuxiliaryLoss(torch.autograd.Function):
342
  """
343
+ Helper function for adding the auxiliary (aux) loss,
344
  which includes the gradient of the aux loss during backpropagation.
345
  """
 
346
  @staticmethod
347
  def forward(ctx, x, loss):
348
  assert loss.numel() == 1
 
356
  if ctx.required_aux_loss:
357
  grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
358
  return grad_output, grad_loss
359
+
360
+
361
+ class DeepseekMoE(nn.Module):
362
  """
363
  A mixed expert module containing shared experts.
364
  """
 
365
  def __init__(self, config):
366
  super().__init__()
367
  self.config = config
368
  self.num_experts_per_tok = config.num_experts_per_tok
369
+ self.experts = nn.ModuleList([DeepseekMLP(config, intermediate_size = config.moe_intermediate_size) for i in range(config.n_routed_experts)])
370
  self.gate = MoEGate(config)
371
  if config.n_shared_experts is not None:
372
  intermediate_size = config.moe_intermediate_size * config.n_shared_experts
373
+ self.shared_experts = DeepseekMLP(config=config, intermediate_size = intermediate_size)
374
+
 
 
375
  def forward(self, hidden_states):
376
  identity = hidden_states
377
  orig_shape = hidden_states.shape
 
379
  hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
380
  flat_topk_idx = topk_idx.view(-1)
381
  if self.training:
382
+ hidden_states = hidden_states.repeat_interleave(self.num_experts_per_tok, dim=0)
 
 
383
  y = torch.empty_like(hidden_states)
384
  for i, expert in enumerate(self.experts):
385
  y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
386
  y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
387
+ y = y.view(*orig_shape)
388
  y = AddAuxiliaryLoss.apply(y, aux_loss)
389
  else:
390
+ y = self.moe_infer(hidden_states, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
391
  if self.config.n_shared_experts is not None:
392
  y = y + self.shared_experts(identity)
393
  return y
394
+
395
  @torch.no_grad()
396
+ def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
397
+ expert_cache = torch.zeros_like(x)
398
+ idxs = flat_expert_indices.argsort()
399
+ tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
400
+ token_idxs = idxs // self.num_experts_per_tok
401
+ for i, end_idx in enumerate(tokens_per_expert):
402
+ start_idx = 0 if i == 0 else tokens_per_expert[i-1]
403
+ if start_idx == end_idx:
404
  continue
405
+ expert = self.experts[i]
406
+ exp_token_idx = token_idxs[start_idx:end_idx]
407
+ expert_tokens = x[exp_token_idx]
408
+ expert_out = expert(expert_tokens)
409
+ expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
410
+ expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum')
411
+ return expert_cache
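A toy walk-through of the dispatch bookkeeping moe_infer relies on (a made-up routing of 3 tokens to 2 experts each): sorting the flattened expert ids groups token slots per expert, and the cumulative bincount marks each expert's slice.

import torch

num_experts_per_tok = 2
flat_expert_indices = torch.tensor([3, 1, 0, 1, 2, 3])    # 3 tokens x 2 chosen experts
idxs = flat_expert_indices.argsort()                       # slots grouped by expert id
tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
token_idxs = idxs // num_experts_per_tok                   # slot -> owning token row
print(tokens_per_expert.tolist())   # [1, 3, 4, 6]: expert i owns idxs[start:end]
print(token_idxs.tolist())          # which token each sorted slot came from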
412
 
413
 
414
  # Copied from transformers.models.llama.modeling_llama.repeat_kv
 
420
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
421
  if n_rep == 1:
422
  return hidden_states
423
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
 
 
424
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
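A shape-only illustration of the grouped-query KV expansion performed here (arbitrary sizes):

import torch

kv = torch.randn(2, 4, 10, 64)     # (batch, num_key_value_heads, seq_len, head_dim)
n_rep = 3                          # num_attention_heads // num_key_value_heads
expanded = kv[:, :, None, :, :].expand(2, 4, n_rep, 10, 64).reshape(2, 4 * n_rep, 10, 64)
print(expanded.shape)              # torch.Size([2, 12, 10, 64])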
425
 
426
 
427
+ # Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->Deepseek
428
+ class DeepseekAttention(nn.Module):
429
  """Multi-headed attention from 'Attention Is All You Need' paper"""
430
 
431
+ def __init__(self, config: DeepseekConfig, layer_idx: Optional[int] = None):
432
  super().__init__()
433
  self.config = config
434
  self.layer_idx = layer_idx
 
442
  self.attention_dropout = config.attention_dropout
443
  self.hidden_size = config.hidden_size
444
  self.num_heads = config.num_attention_heads
445
+ self.head_dim = self.hidden_size // self.num_heads
446
+ self.num_key_value_heads = config.num_key_value_heads
447
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
448
  self.max_position_embeddings = config.max_position_embeddings
449
  self.rope_theta = config.rope_theta
 
 
 
 
 
 
 
450
  self.is_causal = True
451
 
452
+ if (self.head_dim * self.num_heads) != self.hidden_size:
453
+ raise ValueError(
454
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
455
+ f" and `num_heads`: {self.num_heads})."
 
 
 
 
 
 
 
456
  )
457
 
458
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
459
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
460
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
461
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
462
  self._init_rope()
463
464
  def _init_rope(self):
465
  if self.config.rope_scaling is None:
466
+ self.rotary_emb = DeepseekRotaryEmbedding(
467
+ self.head_dim,
468
  max_position_embeddings=self.max_position_embeddings,
469
  base=self.rope_theta,
470
  )
 
472
  scaling_type = self.config.rope_scaling["type"]
473
  scaling_factor = self.config.rope_scaling["factor"]
474
  if scaling_type == "linear":
475
+ self.rotary_emb = DeepseekLinearScalingRotaryEmbedding(
476
+ self.head_dim,
477
  max_position_embeddings=self.max_position_embeddings,
478
  scaling_factor=scaling_factor,
479
  base=self.rope_theta,
480
  )
481
  elif scaling_type == "dynamic":
482
+ self.rotary_emb = DeepseekDynamicNTKScalingRotaryEmbedding(
483
+ self.head_dim,
484
  max_position_embeddings=self.max_position_embeddings,
485
  scaling_factor=scaling_factor,
486
  base=self.rope_theta,
487
  )
488
  else:
489
  raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
490
 
491
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
492
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
 
 
 
 
493
 
494
  def forward(
495
  self,
 
505
  warnings.warn(
506
  "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
507
  )
508
+
509
  bsz, q_len, _ = hidden_states.size()
510
 
511
+ if self.config.pretraining_tp > 1:
512
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
513
+ query_slices = self.q_proj.weight.split(
514
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
515
+ )
516
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
517
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
518
+
519
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
520
+ query_states = torch.cat(query_states, dim=-1)
521
+
522
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
523
+ key_states = torch.cat(key_states, dim=-1)
524
+
525
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
526
+ value_states = torch.cat(value_states, dim=-1)
527
+
528
  else:
529
+ query_states = self.q_proj(hidden_states)
530
+ key_states = self.k_proj(hidden_states)
531
+ value_states = self.v_proj(hidden_states)
 
 
532
 
533
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
534
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
535
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
 
 
 
 
 
 
536
 
537
+ kv_seq_len = key_states.shape[-2]
 
 
 
538
  if past_key_value is not None:
539
  if self.layer_idx is None:
540
  raise ValueError(
 
544
  )
545
  kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
546
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
547
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
548
549
  if past_key_value is not None:
550
  cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
551
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
 
552
 
553
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
554
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
555
+
556
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
557
 
558
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
559
  raise ValueError(
560
  f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
561
  f" {attn_weights.size()}"
562
  )
563
+
564
  if attention_mask is not None:
565
  if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
566
  raise ValueError(
 
569
  attn_weights = attn_weights + attention_mask
570
 
571
  # upcast attention to fp32
572
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
573
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
 
 
 
 
574
  attn_output = torch.matmul(attn_weights, value_states)
575
 
576
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
577
  raise ValueError(
578
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
579
  f" {attn_output.size()}"
580
  )
581
 
582
  attn_output = attn_output.transpose(1, 2).contiguous()
583
 
584
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
585
 
586
+ if self.config.pretraining_tp > 1:
587
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
588
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
589
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
590
+ else:
591
+ attn_output = self.o_proj(attn_output)
592
 
593
  if not output_attentions:
594
  attn_weights = None
 
596
  return attn_output, attn_weights, past_key_value
597
 
598
 
599
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Deepseek
600
+ class DeepseekFlashAttention2(DeepseekAttention):
601
  """
602
+ Deepseek flash attention module. This module inherits from `DeepseekAttention` as the weights of the module stay
603
  untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
604
  flash attention and deal with padding tokens in case the input contains any of them.
605
  """
 
622
  use_cache: bool = False,
623
  **kwargs,
624
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
625
+ # DeepseekFlashAttention2 attention does not support output_attentions
626
  if "padding_mask" in kwargs:
627
  warnings.warn(
628
  "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
 
635
 
636
  bsz, q_len, _ = hidden_states.size()
637
 
638
+ query_states = self.q_proj(hidden_states)
639
+ key_states = self.k_proj(hidden_states)
640
+ value_states = self.v_proj(hidden_states)
 
 
 
 
 
641
 
642
  # Flash attention requires the input to have the shape
643
  # batch_size x seq_length x head_dim x hidden_dim
644
  # therefore we just need to keep the original shape
645
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
646
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
647
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
648
 
649
+ kv_seq_len = key_states.shape[-2]
650
  if past_key_value is not None:
651
  kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
652
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
653
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
654
 
655
  if past_key_value is not None:
656
  cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
657
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
 
658
 
659
  # TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
660
  # to be able to avoid many of these transpose/reshape/view.
 
668
  # therefore the input hidden states gets silently casted in float32. Hence, we need
669
  # cast them back in the correct dtype just to be sure everything works as expected.
670
  # This might slowdown training & inference so it is recommended to not cast the LayerNorms
671
+ # in fp32. (DeepseekRMSNorm handles it correctly)
672
 
673
  input_dtype = query_states.dtype
674
  if input_dtype == torch.float32:
 
678
  elif torch.is_autocast_enabled():
679
  target_dtype = torch.get_autocast_gpu_dtype()
680
  else:
681
+ target_dtype = self.q_proj.weight.dtype
682
 
683
  logger.warning_once(
684
  f"The input hidden states seem to be silently cast in float32, this might be related to"
 
691
  value_states = value_states.to(target_dtype)
692
 
693
  attn_output = self._flash_attention_forward(
694
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
 
 
 
 
 
 
695
  )
 
 
696
 
697
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
 
 
698
  attn_output = self.o_proj(attn_output)
699
 
700
  if not output_attentions:
 
703
  return attn_output, attn_weights, past_key_value
704
 
705
  def _flash_attention_forward(
706
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
 
 
 
 
 
 
 
707
  ):
708
  """
709
  Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
 
727
  if not self._flash_attn_uses_top_left_mask:
728
  causal = self.is_causal
729
  else:
730
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekFlashAttention2 __init__.
731
  causal = self.is_causal and query_length != 1
732
 
733
  # Contains at least one padding token in the sequence
734
  if attention_mask is not None:
735
  batch_size = query_states.shape[0]
736
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
 
 
 
 
 
 
 
737
  query_states, key_states, value_states, attention_mask, query_length
738
  )
739
 
 
753
  causal=causal,
754
  )
755
 
756
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
 
 
757
  else:
758
  attn_output = flash_attn_func(
759
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
 
 
 
 
 
760
  )
761
 
762
  return attn_output
763
 
764
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
 
 
765
  indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
766
  batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
767
 
768
  key_layer = index_first_axis(
769
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
 
770
  )
771
  value_layer = index_first_axis(
772
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
 
773
  )
774
  if query_length == kv_seq_len:
775
  query_layer = index_first_axis(
776
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
 
777
  )
778
  cu_seqlens_q = cu_seqlens_k
779
  max_seqlen_in_batch_q = max_seqlen_in_batch_k
 
788
  else:
789
  # The -q_len: slice assumes left padding.
790
  attention_mask = attention_mask[:, -query_length:]
791
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
 
 
792
 
793
  return (
794
  query_layer,
 
800
  )
801
 
802
 
803
+ # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Deepseek
804
+ class DeepseekSdpaAttention(DeepseekAttention):
805
+ """
806
+ Deepseek attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
807
+ `DeepseekAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to the
808
+ SDPA API.
809
+ """
810
+
811
+ # Adapted from DeepseekAttention.forward
812
+ def forward(
813
+ self,
814
+ hidden_states: torch.Tensor,
815
+ attention_mask: Optional[torch.Tensor] = None,
816
+ position_ids: Optional[torch.LongTensor] = None,
817
+ past_key_value: Optional[Cache] = None,
818
+ output_attentions: bool = False,
819
+ use_cache: bool = False,
820
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
821
+ if output_attentions:
822
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
823
+ logger.warning_once(
824
+ "DeepseekModel is using DeepseekSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
825
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
826
+ )
827
+ return super().forward(
828
+ hidden_states=hidden_states,
829
+ attention_mask=attention_mask,
830
+ position_ids=position_ids,
831
+ past_key_value=past_key_value,
832
+ output_attentions=output_attentions,
833
+ use_cache=use_cache,
834
+ )
835
+
836
+ bsz, q_len, _ = hidden_states.size()
837
+
838
+ query_states = self.q_proj(hidden_states)
839
+ key_states = self.k_proj(hidden_states)
840
+ value_states = self.v_proj(hidden_states)
841
+
842
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
843
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
844
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
845
+
846
+ kv_seq_len = key_states.shape[-2]
847
+ if past_key_value is not None:
848
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
849
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
850
+
851
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
852
+
853
+ if past_key_value is not None:
854
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
855
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
856
+
857
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
858
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
859
+
860
+ if attention_mask is not None:
861
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
862
+ raise ValueError(
863
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
864
+ )
865
+
866
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
867
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
868
+ if query_states.device.type == "cuda" and attention_mask is not None:
869
+ query_states = query_states.contiguous()
870
+ key_states = key_states.contiguous()
871
+ value_states = value_states.contiguous()
872
+
873
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
874
+ query_states,
875
+ key_states,
876
+ value_states,
877
+ attn_mask=attention_mask,
878
+ dropout_p=self.attention_dropout if self.training else 0.0,
879
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
880
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
881
+ )
882
+
883
+ attn_output = attn_output.transpose(1, 2).contiguous()
884
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
885
+
886
+ attn_output = self.o_proj(attn_output)
887
+
888
+ return attn_output, None, past_key_value
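The core call this class delegates to is torch.nn.functional.scaled_dot_product_attention (PyTorch 2.0+); a toy standalone invocation with made-up shapes:

import torch

q = torch.randn(1, 8, 5, 64)   # (batch, num_heads, seq_len, head_dim)
k = torch.randn(1, 8, 5, 64)
v = torch.randn(1, 8, 5, 64)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)               # torch.Size([1, 8, 5, 64])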
889
+
890
+
891
+ Deepseek_ATTENTION_CLASSES = {
892
+ "eager": DeepseekAttention,
893
+ "flash_attention_2": DeepseekFlashAttention2,
894
+ "sdpa": DeepseekSdpaAttention,
895
  }
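Assuming the surrounding config plumbing matches upstream transformers (v4.36+), the implementation key can be chosen at load time; a hypothetical call where the checkpoint path is a placeholder:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "path/to/checkpoint",                     # placeholder, not a real repo id
    trust_remote_code=True,
    attn_implementation="flash_attention_2",  # or "eager" / "sdpa"
)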
896
 
897
 
898
+ class DeepseekDecoderLayer(nn.Module):
899
+ def __init__(self, config: DeepseekConfig, layer_idx: int):
900
  super().__init__()
901
  self.hidden_size = config.hidden_size
902
 
903
+ self.self_attn = Deepseek_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
 
 
904
 
905
+ self.mlp = DeepseekMoE(config) if (config.n_routed_experts is not None and \
906
+ layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0) \
907
+ else DeepseekMLP(config)
908
+ self.input_layernorm = DeepseekRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
909
+ self.post_attention_layernorm = DeepseekRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
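With this selection rule, the first first_k_dense_replace layers stay dense and every moe_layer_freq-th layer after that becomes a DeepseekMoE block; a quick check with assumed config values (not read from any real config):

n_routed_experts, first_k_dense_replace, moe_layer_freq = 64, 1, 1
for layer_idx in range(4):
    use_moe = (n_routed_experts is not None
               and layer_idx >= first_k_dense_replace
               and layer_idx % moe_layer_freq == 0)
    print(layer_idx, "DeepseekMoE" if use_moe else "DeepseekMLP")
# -> layer 0 uses the dense MLP, layers 1-3 use MoE with these values.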
 
 
 
 
 
 
 
 
 
 
910
 
911
  def forward(
912
  self,
 
917
  output_attentions: Optional[bool] = False,
918
  use_cache: Optional[bool] = False,
919
  **kwargs,
920
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
 
 
921
  """
922
  Args:
923
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
 
969
  return outputs
970
 
971
 
972
+ Deepseek_START_DOCSTRING = r"""
973
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
974
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
975
  etc.)
 
979
  and behavior.
980
 
981
  Parameters:
982
+ config ([`DeepseekConfig`]):
983
  Model configuration class with all the parameters of the model. Initializing with a config file does not
984
  load the weights associated with the model, only the configuration. Check out the
985
  [`~PreTrainedModel.from_pretrained`] method to load the model weights.
 
987
 
988
 
989
  @add_start_docstrings(
990
+ "The bare Deepseek Model outputting raw hidden-states without any specific head on top.",
991
+ Deepseek_START_DOCSTRING,
992
  )
993
+ class DeepseekPreTrainedModel(PreTrainedModel):
994
+ config_class = DeepseekConfig
995
  base_model_prefix = "model"
996
  supports_gradient_checkpointing = True
997
+ _no_split_modules = ["DeepseekDecoderLayer"]
998
  _skip_keys_device_placement = "past_key_values"
999
  _supports_flash_attn_2 = True
1000
+ _supports_sdpa = True
1001
  _supports_cache_class = True
1002
 
1003
  def _init_weights(self, module):
 
1012
  module.weight.data[module.padding_idx].zero_()
1013
 
1014
 
1015
+ Deepseek_INPUTS_DOCSTRING = r"""
1016
  Args:
1017
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1018
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
 
1083
 
1084
 
1085
  @add_start_docstrings(
1086
+ "The bare Deepseek Model outputting raw hidden-states without any specific head on top.",
1087
+ Deepseek_START_DOCSTRING,
1088
  )
1089
+ class DeepseekModel(DeepseekPreTrainedModel):
1090
  """
1091
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekDecoderLayer`]
1092
 
1093
  Args:
1094
+ config: DeepseekConfig
1095
  """
1096
 
1097
+ def __init__(self, config: DeepseekConfig):
1098
  super().__init__(config)
1099
  self.padding_idx = config.pad_token_id
1100
  self.vocab_size = config.vocab_size
1101
 
1102
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 
 
1103
  self.layers = nn.ModuleList(
1104
+ [DeepseekDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
 
 
 
1105
  )
1106
+ self._use_sdpa = config._attn_implementation == "sdpa"
1107
  self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1108
+ self.norm = DeepseekRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1109
 
1110
  self.gradient_checkpointing = False
1111
  # Initialize weights and apply final processing
 
1117
  def set_input_embeddings(self, value):
1118
  self.embed_tokens = value
1119
 
1120
+ @add_start_docstrings_to_model_forward(Deepseek_INPUTS_DOCSTRING)
1121
  def forward(
1122
  self,
1123
  input_ids: torch.LongTensor = None,
 
1130
  output_hidden_states: Optional[bool] = None,
1131
  return_dict: Optional[bool] = None,
1132
  ) -> Union[Tuple, BaseModelOutputWithPast]:
1133
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
1134
  output_hidden_states = (
1135
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
1136
  )
1137
  use_cache = use_cache if use_cache is not None else self.config.use_cache
1138
 
1139
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
1140
 
1141
  # retrieve input_ids and inputs_embeds
1142
  if input_ids is not None and inputs_embeds is not None:
1143
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
 
 
1144
  elif input_ids is not None:
1145
  batch_size, seq_length = input_ids.shape[:2]
1146
  elif inputs_embeds is not None:
 
1165
  if position_ids is None:
1166
  device = input_ids.device if input_ids is not None else inputs_embeds.device
1167
  position_ids = torch.arange(
1168
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
 
 
 
1169
  )
1170
  position_ids = position_ids.unsqueeze(0)
1171
 
 
1174
 
1175
  if self._use_flash_attention_2:
1176
  # 2d mask is passed through the layers
1177
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1178
+ elif self._use_sdpa and not output_attentions:
1179
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
1180
+ # the manual implementation that requires a 4D causal mask in all cases.
1181
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
 
 
 
1182
  attention_mask,
1183
  (batch_size, seq_length),
1184
  inputs_embeds,
1185
  past_key_values_length,
1186
  )
1187
+ else:
1188
+ # 4d mask is passed through the layers
1189
+ attention_mask = _prepare_4d_causal_attention_mask(
1190
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
1191
+ )
1192
 
1193
  # embed positions
1194
  hidden_states = inputs_embeds
 
1238
 
1239
  next_cache = None
1240
  if use_cache:
1241
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
 
 
 
 
1242
  if not return_dict:
1243
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
 
 
 
 
1244
  return BaseModelOutputWithPast(
1245
  last_hidden_state=hidden_states,
1246
  past_key_values=next_cache,
 
1249
  )
1250
 
1251
 
1252
+ class DeepseekForCausalLM(DeepseekPreTrainedModel):
1253
  _tied_weights_keys = ["lm_head.weight"]
1254
 
1255
  def __init__(self, config):
1256
  super().__init__(config)
1257
+ self.model = DeepseekModel(config)
1258
  self.vocab_size = config.vocab_size
1259
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1260
 
 
1279
  def get_decoder(self):
1280
  return self.model
1281
 
1282
+ @add_start_docstrings_to_model_forward(Deepseek_INPUTS_DOCSTRING)
1283
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 
 
1284
  def forward(
1285
  self,
1286
  input_ids: torch.LongTensor = None,
 
1306
  Example:
1307
 
1308
  ```python
1309
+ >>> from transformers import AutoTokenizer, DeepseekForCausalLM
1310
 
1311
+ >>> model = DeepseekForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1312
  >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1313
 
1314
  >>> prompt = "Hey, are you conscious? Can you talk to me?"
 
1319
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1320
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1321
  ```"""
1322
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
 
 
 
1323
  output_hidden_states = (
1324
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
 
 
 
1325
  )
1326
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1327
 
1328
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1329
  outputs = self.model(
 
1339
  )
1340
 
1341
  hidden_states = outputs[0]
1342
+ if self.config.pretraining_tp > 1:
1343
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1344
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1345
+ logits = torch.cat(logits, dim=-1)
1346
+ else:
1347
+ logits = self.lm_head(hidden_states)
1348
  logits = logits.float()
1349
 
1350
  loss = None
 
1373
  )
1374
 
1375
  def prepare_inputs_for_generation(
1376
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
 
 
 
 
 
1377
  ):
1378
  if past_key_values is not None:
1379
  if isinstance(past_key_values, Cache):
 
1388
  # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1389
  # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1390
  # input)
1391
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
 
 
 
1392
  input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1393
  # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1394
  # input_ids based on the past_length.
 
1433
  reordered_past = ()
1434
  for layer_past in past_key_values:
1435
  reordered_past += (
1436
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
 
 
 
1437
  )
1438
  return reordered_past
1439
 
1440
 
1441
  @add_start_docstrings(
1442
  """
1443
+ The Deepseek Model transformer with a sequence classification head on top (linear layer).
1444
 
1445
+ [`DeepseekForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1446
  (e.g. GPT-2) do.
1447
 
1448
  Since it does classification on the last token, it requires to know the position of the last token. If a
 
1451
  padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1452
  each row of the batch).
1453
  """,
1454
+ Deepseek_START_DOCSTRING,
1455
  )
1456
+ class DeepseekForSequenceClassification(DeepseekPreTrainedModel):
1457
  def __init__(self, config):
1458
  super().__init__(config)
1459
  self.num_labels = config.num_labels
1460
+ self.model = DeepseekModel(config)
1461
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1462
 
1463
  # Initialize weights and apply final processing
 
1469
  def set_input_embeddings(self, value):
1470
  self.model.embed_tokens = value
1471
 
1472
+ @add_start_docstrings_to_model_forward(Deepseek_INPUTS_DOCSTRING)
1473
  def forward(
1474
  self,
1475
  input_ids: torch.LongTensor = None,
 
1489
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1490
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1491
  """
1492
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
1493
 
1494
  transformer_outputs = self.model(
1495
  input_ids,
 
1511
  batch_size = inputs_embeds.shape[0]
1512
 
1513
  if self.config.pad_token_id is None and batch_size != 1:
1514
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
 
 
1515
  if self.config.pad_token_id is None:
1516
  sequence_lengths = -1
1517
  else:
1518
  if input_ids is not None:
1519
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
1520
+ logits.device
1521
+ )
1522
  else:
1523
  sequence_lengths = -1
1524
 
1525
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
 
 
1526
 
1527
  loss = None
1528
  if labels is not None:
 
1530
  if self.config.problem_type is None:
1531
  if self.num_labels == 1:
1532
  self.config.problem_type = "regression"
1533
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
 
 
1534
  self.config.problem_type = "single_label_classification"
1535
  else:
1536
  self.config.problem_type = "multi_label_classification"
 
1543
  loss = loss_fct(pooled_logits, labels)
1544
  elif self.config.problem_type == "single_label_classification":
1545
  loss_fct = CrossEntropyLoss()
1546
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
 
 
1547
  elif self.config.problem_type == "multi_label_classification":
1548
  loss_fct = BCEWithLogitsLoss()
1549
  loss = loss_fct(pooled_logits, labels)
 
1557
  past_key_values=transformer_outputs.past_key_values,
1558
  hidden_states=transformer_outputs.hidden_states,
1559
  attentions=transformer_outputs.attentions,
1560
+ )
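The last-token pooling in this head depends on locating the final non-padding position per row; a toy illustration with made-up token ids (assumes each row contains at least one pad token, as the model code does):

import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],
                          [8, 9, 0, 0, 0]])
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
print(sequence_lengths.tolist())   # [2, 1] -> index of the last real token in each row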