mjschock committed
Commit 0c131b2 · verified · 1 Parent(s): 282c7ea

Upload model

Files changed (3):
  1. config.json +6 -1
  2. model.safetensors +2 -2
  3. modeling_mamba.py +143 -388
config.json CHANGED
@@ -1,6 +1,10 @@
 {
+  "architectures": [
+    "MambaModelForCausalLM"
+  ],
   "auto_map": {
-    "AutoConfig": "configuration_mamba.MambaConfig"
+    "AutoConfig": "configuration_mamba.MambaConfig",
+    "AutoModelForCausalLM": "modeling_mamba.MambaModelForCausalLM"
   },
   "bias": false,
   "conv_bias": true,
@@ -14,6 +18,7 @@
   "model_type": "mamba",
   "n_layer": 24,
   "pad_vocab_size_multiple": 8,
+  "torch_dtype": "float32",
   "transformers_version": "4.37.2",
   "vocab_size": 50280
 }
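
With the new auto_map entries, the checkpoint can be loaded through the Auto classes with trust_remote_code=True. A minimal loading sketch under that assumption; the repo id is a placeholder rather than something taken from this commit:

    from transformers import AutoConfig, AutoModelForCausalLM

    repo_id = "<user>/<mamba-repo>"  # placeholder; substitute the Hub repo this commit belongs to

    # trust_remote_code=True lets transformers import configuration_mamba.MambaConfig and
    # modeling_mamba.MambaModelForCausalLM from the repo, as declared in auto_map.
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
    print(type(model).__name__)  # MambaModelForCausalLM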
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:1bd3ca62665de4bfabff9d443f87a11090a10e505c0ccb56e6f9ca495b6e05bd
-size 671027808
+oid sha256:699ed6f59fb948186f449c5031e0dc659d504c90d7e018302aa1e190cdb40220
+size 516567560
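
The checkpoint shrinks from 671,027,808 to 516,567,560 bytes, roughly the size of one float32 vocab_size × d_model matrix; that is consistent with the rewritten modeling code tying lm_head.weight to the embedding so only one copy is stored, though this reading is an inference from the diff rather than something stated in the commit. A short sketch for inspecting what the file actually contains, assuming it has been downloaded locally as model.safetensors:

    # List tensor names and shapes stored in the checkpoint.
    from safetensors import safe_open

    with safe_open("model.safetensors", framework="pt") as f:
        for name in f.keys():
            print(name, tuple(f.get_tensor(name).shape))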
modeling_mamba.py CHANGED
@@ -1,43 +1,38 @@
-import json
-import math
-import os
-from collections import namedtuple
-from dataclasses import dataclass
-from functools import partial
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import transformers
 from einops import einsum, rearrange, repeat
-from torch import FloatTensor, Tensor, nn
-from transformers import GenerationMixin, PreTrainedModel
+from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import (
-    BaseModelOutput,
     BaseModelOutputWithPast,
-    CausalLMOutput,
-    ImageClassifierOutput,
+    CausalLMOutputWithPast,
     QuestionAnsweringModelOutput,
     SequenceClassifierOutput,
 )
-from trl import PreTrainedModelWrapper
+from transformers.modeling_utils import PreTrainedModel
 
 from .configuration_mamba import MambaConfig
 
 
-# Inspired by:
-# - https://huggingface.co/Q-bert/Mamba-130M/blob/f0d00db98acaa62b1ee4304cd11643e69aa62a71/modeling_mamba.py#L31
-# - https://github.com/johnma2006/mamba-minimal/blob/03de542a36d873f6e6c4057ad687278cc6ae944d/model.py#L177
-# - https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/mamba_ssm/modules/mamba_simple.py#L31
-class MambaBlock(nn.Module):
-    def __init__(self, config: MambaConfig):
-        """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1].
-        Furthermore, in section E.2.2 of the paper, the authors describe the Mamba block as:
-            "[T]he Mamba block is simply the standard SwiGLU block with an extra conv → SSM path added."
-        """
+class MambaRMSNorm(nn.Module):
+    def __init__(self, d_model: int, eps: float = 1e-5):
         super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(d_model))
+
+    def forward(self, x):
+        output = (
+            x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
+        )
+        return output
 
+
+class Mamba(nn.Module):
+    def __init__(self, config: MambaConfig):
+        """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1]."""
+        super().__init__()
         self.config = config
 
         self.in_proj = nn.Linear(config.d_model, config.d_inner * 2, bias=config.bias)
@@ -62,9 +57,8 @@ class MambaBlock(nn.Module):
         A = repeat(torch.arange(1, config.d_state + 1), "n -> d n", d=config.d_inner)
         self.A_log = nn.Parameter(torch.log(A))
         self.D = nn.Parameter(torch.ones(config.d_inner))
-
         self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias)
-        # self.norm = RMSNorm(config.d_model)
+        # self.norm = MambaRMSNorm(config.d_model)
 
     def forward(self, x):
         """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1].
@@ -80,9 +74,10 @@ class MambaBlock(nn.Module):
             mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
 
         """
+
         (b, l, d) = x.shape
-        # x_copy = x # There was a separate class for residual, I deleted that part and added it here.
-        # x = self.norm(x)
+        x_copy = x # There was a separate class for residual, I deleted that part and added it here.
+        x = self.norm(x)
         x_and_res = self.in_proj(x) # shape (b, l, 2 * d_in)
         (x, res) = x_and_res.split(
             split_size=[self.config.d_inner, self.config.d_inner], dim=-1
@@ -96,13 +91,9 @@ class MambaBlock(nn.Module):
 
         y = self.ssm(x)
 
-        y = y * F.silu(
-            res
-        ) # SwiGLU: Swish_β(xW + b) ⊗ (xV + c) => torch.kron(F.silu(xW + b), xV + c) => torch.kron(F.silu(res), y)
+        y = y * F.silu(res)
 
-        output = self.out_proj(y) # output = self.out_proj(y) + x_copy
-
-        # "the Mamba block is simply the standard SwiGLU block with an extra 𝖼𝗈𝗇𝗏 → 𝖲𝖲𝖬 path added"
+        output = self.out_proj(y) + x_copy
 
         return output
 
@@ -177,21 +168,17 @@ class MambaBlock(nn.Module):
         # Discretize continuous parameters (A, B)
         # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1])
        # - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
-        #   "A is the more important term and the performance doesn't change much with the simplification on B"
-        deltaA = torch.exp(einsum(delta, A, "b l d_in, d_in n -> b l d_in n"))
-        deltaB_u = einsum(delta, B, u, "b l d_in, b l n, b l d_in -> b l d_in n")
+        #   "A is the more important term and the performance doesn't change much with the simplication on B"
+        deltaA = torch.exp(einsum(delta, A, "b l d_in, d_in n -> b d_in l n"))
+        deltaB_u = einsum(delta, B, u, "b l d_in, b l n, b l d_in -> b d_in l n")
 
         # Perform selective scan (see scan_SSM() in The Annotated S4 [2])
-        # Note that the below is sequential, while the official implementation does a much faster parallel scan that
-        # is additionally hardware-aware (like FlashAttention).
         x = torch.zeros((b, d_in, n), device=deltaA.device)
         ys = []
-
         for i in range(l):
-            x = deltaA[:, i] * x + deltaB_u[:, i]
+            x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
             y = einsum(x, C[:, i, :], "b d_in n, b n -> b d_in")
             ys.append(y)
-
         y = torch.stack(ys, dim=1) # shape (b, l, d_in)
 
         y = y + u * D
@@ -199,395 +186,163 @@ class MambaBlock(nn.Module):
         return y
 
 
-# Inspired by:
-# - https://huggingface.co/Q-bert/Mamba-130M/blob/f0d00db98acaa62b1ee4304cd11643e69aa62a71/modeling_mamba.py#L19
-# - https://github.com/johnma2006/mamba-minimal/blob/03de542a36d873f6e6c4057ad687278cc6ae944d/model.py#L328
-# - https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/mamba_ssm/ops/triton/layernorm.py#L481
-class RMSNorm(nn.Module):
-    def __init__(self, d_model: int, eps: float = 1e-5):
-        super().__init__()
-
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(d_model))
-
-    def forward(self, x):
-        output = (
-            x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
-        )
-
-        return output
-
-
-class ResidualBlock(
-    nn.Module
-): # Copied and modified from https://github.com/johnma2006/mamba-minimal/blob/03de542a36d873f6e6c4057ad687278cc6ae944d/model.py#L143
+class Block(nn.Module):
     def __init__(self, config: MambaConfig):
-        """Simple block wrapping Mamba block with normalization and residual connection."""
+        """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1]."""
         super().__init__()
+        self.config = config
 
-        # self.args = args
-        self.mixer = MambaBlock(config)
-        self.norm = RMSNorm(config.d_model)
-        # self.norm = partial(
-        #     nn.LayerNorm if not config.rms_norm else RMSNorm, eps=config.norm_epsilon,
-        # )
-
-        # if config.rms_norm:
-        #     self.norm = RMSNorm(config.d_model, eps=config.norm_epsilon)
-
-        # else:
-        #     self.norm = nn.LayerNorm(config.d_model, eps=config.norm_epsilon)
-
-    def forward(self, x):
-        """
-        Args:
-            x: shape (b, l, d) (See Glossary at top for definitions of b, l, d_in, n...)
-
-        Returns:
-            output: shape (b, l, d)
-
-        Official Implementation:
-            Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297
-
-        Note: the official repo chains residual blocks that look like
-            [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
-        where the first Add is a no-op. This is purely for performance reasons as this
-        allows them to fuse the Add->Norm.
-
-        We instead implement our blocks as the more familiar, simpler, and numerically equivalent
-            [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ....
-
-        """
-        output = self.mixer(self.norm(x)) + x
-
-        return output
-
-
-# Inspired by:
-# - https://huggingface.co/Q-bert/Mamba-130M/blob/f0d00db98acaa62b1ee4304cd11643e69aa62a71/modeling_mamba.py#L181
-# class MambaPretrainedModel(PreTrainedModel, nn.Module):
-class MambaPretrainedModel(PreTrainedModel):
-    r"""
-    Base class for all models.
-
-    [`PreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
-    downloading and saving models as well as a few methods common to all models to:
-
-        - resize the input embeddings,
-        - prune heads in the self-attention heads.
-
-    Class attributes (overridden by derived classes):
-
-        - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
-          for this model architecture.
-        - **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
-          taking as arguments:
+        self.mixer = Mamba(config)
+        self.norm = MambaRMSNorm(config.d_model)
 
-            - **model** ([`PreTrainedModel`]) -- An instance of the model on which to load the TensorFlow checkpoint.
-            - **config** ([`PreTrainedConfig`]) -- An instance of the configuration associated to the model.
-            - **path** (`str`) -- A path to the TensorFlow checkpoint.
 
-        - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
-          classes of the same architecture adding modules on top of the base model.
-        - **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization.
-        - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
-          models, `pixel_values` for vision models and `input_values` for speech models).
-    """
+class MambaBlock(Block):
+    pass
 
-    config_class = MambaConfig # TODO: Build on top of MambaConfig?
-    # base_model_prefix = "backbone"
-    base_model_prefix = "mamba"
-    main_input_name = "input_ids"
-    model_tags = None
 
-    _auto_class = None
-    _no_split_modules = ["MambaBlock"]
-    _skip_keys_device_placement = None
-    _keep_in_fp32_modules = None
-
-    # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
-    # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
-    _keys_to_ignore_on_load_missing = None
-    # a list of `re` patterns of `state_dict` keys that should be removed from the list of
-    # unexpected keys we find (keys inside the checkpoint but not the model) and avoid unnecessary
-    # warnings.
-    _keys_to_ignore_on_load_unexpected = None
-    # a list of `state_dict` keys to ignore when saving the model (useful for keys that aren't
-    # trained, but which are either deterministic or tied variables)
-    _keys_to_ignore_on_save = None
-    # a list of `state_dict` keys that are potentially tied to another key in the state_dict.
-    _tied_weights_keys = None
-
-    is_parallelizable = False
+class MambaPreTrainedModel(PreTrainedModel):
+    config_class = MambaConfig
+    base_model_prefix = "backbone"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["MambaBlock"]
 
-    # Flash Attention 2 support
-    _supports_flash_attn_2 = False
-
-    # SDPA support
-    _supports_sdpa = False
-
-    # Has support for a `Cache` instance as `past_key_values`
-    _supports_cache_class = False
-
-    def __init__(self, *inputs, **kwargs):
-        super().__init__(*inputs, **kwargs)
-
-    # https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/mamba_ssm/models/mixer_seq_simple.py#L54
-    def _init_weights(
-        self,
-        module,
-        initializer_range=0.02, # Now only used for embedding layer.
-        rescale_prenorm_residual=True,
-        n_residuals_per_layer=1, # Change to 2 if we have MLP
-    ):
-        if isinstance(module, nn.Linear):
+    def _init_weights(self, module):
+        std = 0.02
+        if isinstance(module, (nn.Linear, nn.Conv1d)):
+            module.weight.data.normal_(mean=0.0, std=std)
             if module.bias is not None:
-                if not getattr(module.bias, "_no_reinit", False):
-                    nn.init.zeros_(module.bias)
-
+                module.bias.data.zero_()
         elif isinstance(module, nn.Embedding):
-            nn.init.normal_(module.weight, std=initializer_range)
-
-        if rescale_prenorm_residual:
-            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
-            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
-            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
-            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
-            #
-            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
-            for name, p in module.named_parameters():
-                if name in [
-                    "out_proj.weight",
-                    "fc2.weight",
-                ]:
-                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
-                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
-                    # We need to reinit p since this code could be called multiple times
-                    # Having just p *= scale would repeatedly scale it down
-                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
-                    with torch.no_grad():
-                        p /= math.sqrt(n_residuals_per_layer * self.config.n_layer)
-
-    # def _set_gradient_checkpointing(self, module, value=False):
-    #     if isinstance(module, GPT2Model):
-    #         module.gradient_checkpointing = value
-
-# Inspired by:
-# - https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/mamba_ssm/models/mixer_seq_simple.py#L86
-class MambaModel(MambaPretrainedModel):
-    def __init__(self, config: MambaConfig = MambaConfig(), **kwargs) -> None:
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+
+class MambaModel(MambaPreTrainedModel):
+    def __init__(self, config: MambaConfig):
         """Full Mamba model.
         Mamba model decoder consisting of *config.n_layer* layers. Each layer is a [`MambaBlock`]
+
         Args:
            config: MambaConfig
         """
-        super().__init__(
-            config,
-            **kwargs,
-        )
-
-        self.embedding = nn.Embedding(
-            num_embeddings=self.config.vocab_size,
-            embedding_dim=self.config.d_model,
-        )
-
-        self.layers = nn.ModuleList(
-            [ResidualBlock(self.config) for _ in range(self.config.n_layer)]
-        )
-        self.norm_f = RMSNorm(d_model=self.config.d_model)
-        # self.norm_f = (nn.LayerNorm if not self.config.rms_norm else RMSNorm)(
-        #     # self.config.d_model, eps=self.config.norm_epsilon, **factory_kwargs
-        #     self.config.d_model, eps=self.config.norm_epsilon,
-        # )
+        super().__init__(config)
+        self.config = config
 
-        # self.gradient_checkpointing = False
+        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
+        self.layers = nn.ModuleList([MambaBlock(config) for _ in range(config.n_layer)])
+        self.norm_f = MambaRMSNorm(config.d_model)
 
-        # Initialize weights and apply final processing
+        self.gradient_checkpointing = False
         self.post_init()
 
-    # def _init_weights(self, module):
-    #     std = 0.02
+    def get_input_embeddings(self):
+        return self.embedding
 
-    #     if isinstance(module, (nn.Linear, nn.Conv1d)):
-    #         module.weight.data.normal_(mean=0.0, std=std)
-
-    #         if module.bias is not None:
-    #             module.bias.data.zero_()
-
-    #     elif isinstance(module, nn.Embedding):
-    #         module.weight.data.normal_(mean=0.0, std=std)
-
-    #         if module.padding_idx is not None:
-    #             module.weight.data[module.padding_idx].zero_()
-
-    # def get_input_embeddings(self):
-    #     return self.embed_out
-
-    # def set_input_embeddings(self, value):
-    #     self.embed_out = value
+    def set_input_embeddings(self, value):
+        self.embedding = value
 
     def forward(
         self,
         input_ids: torch.LongTensor = None,
-        output_hidden_states=False,
         return_dict: Optional[bool] = None,
-        **kwargs,
-        # ) -> BaseModelOutput:
     ) -> Union[Tuple, BaseModelOutputWithPast]:
-        batch_size = input_ids.shape[0]
-        hidden_size = self.config.d_model
-        hidden_states: Tuple[Tensor[(batch_size, sequence_length, hidden_size)]] = ()
-        sequence_length = input_ids.shape[1]
-        output_hidden_states = output_hidden_states or self.config.output_hidden_states
-
-        last_hidden_state = self.embedding(input_ids)
-        assert last_hidden_state.shape == (
-            batch_size,
-            sequence_length,
-            hidden_size,
-        ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
-        hidden_states += (last_hidden_state,)
-
+        x = self.embedding(input_ids)
+        all_hidden_states = list()
         for layer in self.layers:
-            last_hidden_state = layer(last_hidden_state)
-            assert last_hidden_state.shape == (
-                batch_size,
-                sequence_length,
-                hidden_size,
-            ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
-            hidden_states += (last_hidden_state,)
-
-        last_hidden_state = self.norm_f(last_hidden_state)
-        assert last_hidden_state.shape == (
-            batch_size,
-            sequence_length,
-            hidden_size,
-        ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
-        hidden_states += (last_hidden_state,)
-
-        assert (
-            len(hidden_states) == self.config.n_layer + 2
-        ), f"{len(hidden_states)} != {self.config.n_layer + 2}"
-
-        # return BaseModelOutput(
+            x = layer(x)
+            all_hidden_states.append(x)
+
+        hidden_states = self.norm_f(x)
+
         return BaseModelOutputWithPast(
-            hidden_states=hidden_states if output_hidden_states else None,
-            last_hidden_state=last_hidden_state,
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
         )
 
 
-# Influences:
-# - https://huggingface.co/Q-bert/Mamba-130M/blob/f0d00db98acaa62b1ee4304cd11643e69aa62a71/modeling_mamba.py#L238
-# - https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/mamba_ssm/models/mixer_seq_simple.py#L176
-# class MambaModelForCausalLM(MambaModel, GenerationMixin):
-# class MambaModelForCausalLM(PreTrainedModel, GenerationMixin):
-# class MambaLMHeadModel(MambaPretrainedModel, GenerationMixin):
-class MambaLMHeadModel(MambaPretrainedModel):
-    _tied_weights_keys = [
-        "backbone.embedding.weight",
-        "lm_head.weight",
-    ]
+class MambaModelForCausalLM(MambaPreTrainedModel):
+    _tied_weights_keys = ["lm_head.weight"]
 
-    def __init__(
-        self,
-        config: MambaConfig = MambaConfig(),
-        **kwargs,
-    ) -> None:
-        super().__init__(
-            config,
-            **kwargs,
-        )
+    def __init__(self, config):
+        super().__init__(config)
+        self.backbone = MambaModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+        self.lm_head.weight = self.backbone.embedding.weight
+        self.post_init()
 
-        self.backbone = MambaModel(
-            config=self.config,
-        )
+    # def get_input_embeddings(self):
+    #     return self.model.embedding
 
-        self.lm_head = nn.Linear(
-            in_features=self.config.d_model,
-            out_features=self.config.vocab_size,
-            bias=False,
-        )
+    # def set_input_embeddings(self, value):
+    #     self.model.embedding = value
 
-        # # self.head.weight = self.backbone.embedding.weight # TODO: there's some logic in GenerationMix that does this
+    # def get_output_embeddings(self):
+    #     return self.lm_head
 
-        # Initialize weights and apply final processing
-        self.post_init()
+    # def set_output_embeddings(self, new_embeddings):
+    #     self.lm_head = new_embeddings
 
-    def forward(
-        self, input_ids, output_hidden_states=False, **kwargs
-    ) -> CausalLMOutput:
-        batch_size = input_ids.shape[0]
-        sequence_length = input_ids.shape[1]
-        vocab_size = self.config.vocab_size
-        output_hidden_states = output_hidden_states or self.config.output_hidden_states
+    # def set_decoder(self, decoder):
+    #     self.model = decoder
 
+    # def get_decoder(self):
+    #     return self.model
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
         outputs = self.backbone(
             input_ids=input_ids,
-            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+        logits = logits.float()
+        loss = None
+
+        if labels is not None:
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
         )
 
-        last_hidden_state = outputs.last_hidden_state
-
-        logits: torch.FloatTensor[batch_size, sequence_length, vocab_size] = (
-            self.lm_head(
-                last_hidden_state,
-            )
-        )
-
-        return CausalLMOutput(
-            hidden_states=outputs.hidden_states if output_hidden_states else None,
-            logits=logits,
-        )
-
-    def prepare_inputs_for_generation(
-        self, input_ids, attention_mask=None, **model_kwargs
-    ):
-        return {
-            "input_ids": input_ids,
-        }
-
-
-# class MultimodalMambaModelForCausalLMWithValueHead(PreTrainedModelWrapper):
-#     lm_head_namings: Tuple[str, str] = ("lm_head", "embed_out")
-#     transformers_parent_class: transformers.PreTrainedModel = transformers.AutoModelForCausalLM
-
-#     # def __init__(
-#     #     self,
-#     #     config: MultimodalMambaConfig = MultimodalMambaConfig(),
-#     #     **kwargs,
-#     # ) -> None:
-#     #     super().__init__(
-#     #         config,
-#     #         **kwargs,
-#     #     )
-
-#     #     self.model = MultimodalMambaModelForCausalLM(
-#     #         config=config,
-#     #     )
-
-#     #     self.value_head = nn.Linear(
-#     #         in_features=config.embedding_dim,
-#     #         out_features=1,
-#     #         bias=False,
-#     #     )
-
-#     # def forward(
-#     #     self, input_ids, output_hidden_states=False, **kwargs
-#     # ) -> CausalLMOutput:
-#     #     outputs = self.model(
-#     #         input_ids=input_ids,
-#     #         output_hidden_states=output_hidden_states,
-#     #     )
-
-#     #     last_hidden_state = outputs.last_hidden_state
-
-#     #     value: torch.FloatTensor[batch_size, sequence_length, 1] = self.value_head(
-#     #         last_hidden_state,
-#     #     )
-
-#     #     return CausalLMOutput(
-#     #         hidden_states=outputs.hidden_states if output_hidden_states else None,
-#     #         logits=outputs.logits,
-#     #         value=value,
-#     #     )
+    def prepare_inputs_for_generation(self, input_ids, **kwargs):
+        model_inputs = {"input_ids": input_ids}
+        return model_inputs
+
+
+class MambaModelForSequenceClassification(MambaPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = MambaModel(config)
+        # self.classifier = nn.Linear(config.d_model, config.num_labels)
+        # self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> SequenceClassifierOutput:
+        pass
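
The rewritten MambaModelForCausalLM.forward computes a shifted next-token cross-entropy loss whenever labels are passed, and prepare_inputs_for_generation lets the inherited generate() drive decoding. A usage sketch under the same placeholder repo id as above; batch shapes and decoding settings are illustrative only:

    import torch
    from transformers import AutoModelForCausalLM

    repo_id = "<user>/<mamba-repo>"  # placeholder, as in the loading sketch above
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

    batch = torch.randint(0, model.config.vocab_size, (2, 32))  # illustrative token ids

    # Passing labels triggers the loss branch: logits are shifted left, labels right,
    # and CrossEntropyLoss is taken over the flattened vocab dimension.
    out = model(input_ids=batch, labels=batch, return_dict=True)
    print(out.loss, out.logits.shape)  # scalar loss, (2, 32, 50280)

    # Greedy decoding goes through prepare_inputs_for_generation defined in the diff.
    generated = model.generate(batch[:, :8], max_new_tokens=16, do_sample=False)
    print(generated.shape)  # (2, 24)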