mjschock committed on
Commit 70327f3 · verified · 1 Parent(s): 358373d

Upload model

Files changed (3):
  1. config.json +6 -1
  2. model.safetensors +2 -2
  3. modeling_mamba.py +596 -61
config.json CHANGED
@@ -1,6 +1,10 @@
  {
+   "architectures": [
+     "MambaLMHeadModel"
+   ],
    "auto_map": {
-     "AutoConfig": "configuration_mamba.MambaConfig"
+     "AutoConfig": "configuration_mamba.MambaConfig",
+     "AutoModelForCausalLM": "modeling_mamba.MambaLMHeadModel"
    },
    "bias": false,
    "conv_bias": true,
@@ -14,6 +18,7 @@
    "model_type": "mamba",
    "n_layer": 24,
    "pad_vocab_size_multiple": 8,
+   "torch_dtype": "float32",
    "transformers_version": "4.37.2",
    "vocab_size": 50280
  }
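Note: with "architectures" and the AutoModelForCausalLM entry now present in auto_map, the checkpoint can be resolved to the custom classes through the transformers Auto API. A minimal sketch, with the repo id left as a placeholder for this repository:

    from transformers import AutoConfig, AutoModelForCausalLM

    repo_id = "<namespace>/<repo>"  # placeholder for this repository's Hub id

    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)  # -> configuration_mamba.MambaConfig
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)  # -> modeling_mamba.MambaLMHeadModel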
model.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:6504b24e9ba95e4a6bad94a346c849040623647d1a99a47f4f5e1cd32cbd9572
- size 259551392
+ oid sha256:1bd3ca62665de4bfabff9d443f87a11090a10e505c0ccb56e6f9ca495b6e05bd
+ size 671027808
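The LFS pointer above identifies the new weight blob by its SHA-256 and byte size. As an illustrative aside (not part of the commit), a downloaded model.safetensors can be checked against the pointer with the Python standard library; the local path below is a placeholder:

    import hashlib
    import os

    path = "model.safetensors"  # placeholder: local path of the downloaded file
    expected_oid = "1bd3ca62665de4bfabff9d443f87a11090a10e505c0ccb56e6f9ca495b6e05bd"
    expected_size = 671027808

    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)

    assert os.path.getsize(path) == expected_size
    assert digest.hexdigest() == expected_oid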
modeling_mamba.py CHANGED
@@ -1,82 +1,617 @@
- from typing import Optional, Tuple

- from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
  import torch
  from transformers import GenerationMixin, PreTrainedModel
- from transformers.generation import TextStreamer

  from .configuration_mamba import MambaConfig

- class MambaModel(PreTrainedModel):
-     config_class = MambaConfig

-     def __init__(
          self,
-         config,
-         initializer_cfg=None,
-         device=None,
-         dtype=None,
-         **kwargs,
      ):
          super().__init__(
              config,
              **kwargs,
          )

-         self.model = MambaLMHeadModel(
-             config,
-             initializer_cfg=initializer_cfg,
-             device=device,
-             dtype=dtype,
          )

-     def forward(
-         self,
-         input_ids,
-         position_ids=None,
-         inference_params=None,
-         num_last_tokens=0,
-         **kwargs,
-     ):
-         return self.model.forward(
-             input_ids,
-             position_ids,
-             inference_params,
-             num_last_tokens
          )

- class MambaModelForCausalLM(MambaModel, GenerationMixin):
-     def generate(
          self,
-         input_ids,
-         max_length: int = 2048,
-         top_k: int = 1,
-         top_p: float = 0.0,
-         temperature: float = 1.0,
-         return_dict_in_generate: bool = False,
-         output_scores: bool = False,
-         repetition_penalty: float = 1.0,
-         eos_token_id: Optional[int] = None,
-         teacher_outputs: Optional[torch.Tensor] = None,
-         vocab_size: Optional[int] = None,
-         cg: bool = False,
-         enable_timing: bool = False,
-         streamer: Optional[TextStreamer] = None,
          **kwargs,
-     ):
-         return self.model.generate(
-             input_ids=input_ids,
-             max_length=max_length,
-             top_k=top_k,
-             top_p=top_p,
-             temperature=temperature,
-             return_dict_in_generate=return_dict_in_generate,
-             output_scores=output_scores,
-             repetition_penalty=repetition_penalty,
-             eos_token_id=eos_token_id,
-             teacher_outputs=teacher_outputs,
-             vocab_size=vocab_size,
-             cg=cg,
-             enable_timing=enable_timing,
-             streamer=streamer,
          )
+ import json
+ import math
+ import os
+ from collections import namedtuple
+ from dataclasses import dataclass
+ from functools import partial
+ from typing import Dict, Optional, Tuple, Union

  import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import transformers
+ from einops import einsum, rearrange, repeat
+ from torch import FloatTensor, Tensor, nn
  from transformers import GenerationMixin, PreTrainedModel
+ from transformers.modeling_outputs import (
+     BaseModelOutput,
+     BaseModelOutputWithPast,
+     CausalLMOutput,
+     ImageClassifierOutput,
+     QuestionAnsweringModelOutput,
+     SequenceClassifierOutput,
+ )
+ from trl import PreTrainedModelWrapper

  from .configuration_mamba import MambaConfig

+ # class SwiGLU(nn.Module):
+ #     def forward(self, x, W, V, b, c, beta):
+ #         return F.silu(x * W + b) * (x * V + c)
+
+
+ # Inspired by:
+ # - https://huggingface.co/Q-bert/Mamba-130M/blob/f0d00db98acaa62b1ee4304cd11643e69aa62a71/modeling_mamba.py#L31
+ # - https://github.com/johnma2006/mamba-minimal/blob/03de542a36d873f6e6c4057ad687278cc6ae944d/model.py#L177
+ # - https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/mamba_ssm/modules/mamba_simple.py#L31
+ class MambaBlock(nn.Module):
+     def __init__(self, config: MambaConfig):
+         """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1].
+         Furthermore, in section E.2.2 of the paper, the authors describe the Mamba block as:
+         "[T]he Mamba block is simply the standard SwiGLU block with an extra conv → SSM path added."
+         """
+         super().__init__()
+
+         self.config = config
+
+         self.in_proj = nn.Linear(config.d_model, config.d_inner * 2, bias=config.bias)
+
+         self.conv1d = nn.Conv1d(
+             in_channels=config.d_inner,
+             out_channels=config.d_inner,
+             bias=config.conv_bias,
+             kernel_size=config.d_conv,
+             groups=config.d_inner,
+             padding=config.d_conv - 1,
+         )
+
+         # x_proj takes in `x` and outputs the input-specific Δ, B, C
+         self.x_proj = nn.Linear(
+             config.d_inner, config.dt_rank + config.d_state * 2, bias=False
+         )
+
+         # dt_proj projects Δ from dt_rank to d_in
+         self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True)
+
+         A = repeat(torch.arange(1, config.d_state + 1), "n -> d n", d=config.d_inner)
+         self.A_log = nn.Parameter(torch.log(A))
+         self.D = nn.Parameter(torch.ones(config.d_inner))
+
+         self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias)
+         # self.norm = RMSNorm(config.d_model)
+
+
+
+     def forward(self, x):
+         """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1].
+
+         Args:
+             x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)
+
+         Returns:
+             output: shape (b, l, d)
+
+         Official Implementation:
+             class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
+             mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
+
+         """
+         (b, l, d) = x.shape
+         # x_copy = x  # There was a separate class for residual, I deleted that part and added it here.
+         # x = self.norm(x)
+         x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
+         (x, res) = x_and_res.split(
+             split_size=[self.config.d_inner, self.config.d_inner], dim=-1
+         )
+
+         x = rearrange(x, "b l d_in -> b d_in l")
+         x = self.conv1d(x)[:, :, :l]
+         x = rearrange(x, "b d_in l -> b l d_in")
+
+         x = F.silu(x)
+
+         y = self.ssm(x)
+
+         y = y * F.silu(res)  # SwiGLU: Swish_β(xW + b) ⊗ (xV + c) => torch.kron(F.silu(xW + b), xV + c) => torch.kron(F.silu(res), y)
+
+         output = self.out_proj(y)  # output = self.out_proj(y) + x_copy
+
+         # "the Mamba block is simply the standard SwiGLU block with an extra 𝖼𝗈𝗇𝗏 → 𝖲𝖲𝖬 path added"
+
+         return output
+
+     def ssm(self, x):
+         """Runs the SSM. See:
+             - Algorithm 2 in Section 3.2 in the Mamba paper [1]
+             - run_SSM(A, B, C, u) in The Annotated S4 [2]
+
+         Args:
+             x: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
+
+         Returns:
+             output: shape (b, l, d_in)
+
+         Official Implementation:
+             mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
+
+         """
+         (d_in, n) = self.A_log.shape
+
+         # Compute ∆ A B C D, the state space parameters.
+         #     A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
+         #     ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
+         #     and is why Mamba is called **selective** state spaces)
+
+         A = -torch.exp(self.A_log.float())  # shape (d_in, n)
+         D = self.D.float()
+
+         x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)
+
+         (delta, B, C) = x_dbl.split(
+             split_size=[self.config.dt_rank, n, n], dim=-1
+         )  # delta: (b, l, dt_rank). B, C: (b, l, n)
+         delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)
+
+         y = self.selective_scan(
+             x, delta, A, B, C, D
+         )  # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]
+
+         return y
+
+     def selective_scan(self, u, delta, A, B, C, D):
+         """Does selective scan algorithm. See:
+             - Section 2 State Space Models in the Mamba paper [1]
+             - Algorithm 2 in Section 3.2 in the Mamba paper [1]
+             - run_SSM(A, B, C, u) in The Annotated S4 [2]
+
+         This is the classic discrete state space formula:
+             x(t + 1) = Ax(t) + Bu(t)
+             y(t)     = Cx(t) + Du(t)
+         except B and C (and the step size delta, which is used for discretization) are dependent on the input x(t).
+
+         Args:
+             u: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
+             delta: shape (b, l, d_in)
+             A: shape (d_in, n)
+             B: shape (b, l, n)
+             C: shape (b, l, n)
+             D: shape (d_in,)
+
+         Returns:
+             output: shape (b, l, d_in)
+
+         Official Implementation:
+             selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
+             Note: I refactored some parts out of `selective_scan_ref` out, so the functionality doesn't match exactly.
+
+         """
+         (b, l, d_in) = u.shape
+         n = A.shape[1]
+
+         # Discretize continuous parameters (A, B)
+         # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1])
+         # - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
+         #   "A is the more important term and the performance doesn't change much with the simplification on B"
+         deltaA = torch.exp(einsum(delta, A, "b l d_in, d_in n -> b l d_in n"))
+         deltaB_u = einsum(delta, B, u, "b l d_in, b l n, b l d_in -> b l d_in n")
+
+         # Perform selective scan (see scan_SSM() in The Annotated S4 [2])
+         # Note that the below is sequential, while the official implementation does a much faster parallel scan that
+         # is additionally hardware-aware (like FlashAttention).
+         x = torch.zeros((b, d_in, n), device=deltaA.device)
+         ys = []
+
+         for i in range(l):
+             x = deltaA[:, i] * x + deltaB_u[:, i]
+             y = einsum(x, C[:, i, :], "b d_in n, b n -> b d_in")
+             ys.append(y)
+
+         y = torch.stack(ys, dim=1)  # shape (b, l, d_in)
+
+         y = y + u * D
+
+         return y
+
+
+ # Inspired by:
+ # - https://huggingface.co/Q-bert/Mamba-130M/blob/f0d00db98acaa62b1ee4304cd11643e69aa62a71/modeling_mamba.py#L19
+ # - https://github.com/johnma2006/mamba-minimal/blob/03de542a36d873f6e6c4057ad687278cc6ae944d/model.py#L328
+ # - https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/mamba_ssm/ops/triton/layernorm.py#L481
+ class RMSNorm(nn.Module):
+     def __init__(self, d_model: int, eps: float = 1e-5):
+         super().__init__()
+
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(d_model))
+
+     def forward(self, x):
+         output = (
+             x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
+         )
+
+         return output
+
+
+ class ResidualBlock(
+     nn.Module
+ ):  # Copied and modified from https://github.com/johnma2006/mamba-minimal/blob/03de542a36d873f6e6c4057ad687278cc6ae944d/model.py#L143
+     def __init__(self, config: MambaConfig):
+         """Simple block wrapping Mamba block with normalization and residual connection."""
+         super().__init__()
+
+         # self.args = args
+         self.mixer = MambaBlock(config)
+         self.norm = RMSNorm(config.d_model)
+
+     def forward(self, x):
+         """
+         Args:
+             x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)
+
+         Returns:
+             output: shape (b, l, d)
+
+         Official Implementation:
+             Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297
+
+             Note: the official repo chains residual blocks that look like
+                 [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
+             where the first Add is a no-op. This is purely for performance reasons as this
+             allows them to fuse the Add->Norm.
+
+             We instead implement our blocks as the more familiar, simpler, and numerically equivalent
+                 [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ....
+
+         """
+         output = self.mixer(self.norm(x)) + x
+
+         return output
+
+ # Inspired by:
+ # - https://huggingface.co/Q-bert/Mamba-130M/blob/f0d00db98acaa62b1ee4304cd11643e69aa62a71/modeling_mamba.py#L181
+ # class MambaPretrainedModel(PreTrainedModel, nn.Module):
+ class MambaPretrainedModel(PreTrainedModel):
+     r"""
+     Base class for all models.
+
+     [`PreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
+     downloading and saving models as well as a few methods common to all models to:
+
+         - resize the input embeddings,
+         - prune heads in the self-attention heads.
+
+     Class attributes (overridden by derived classes):
+
+         - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
+           for this model architecture.
+         - **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
+           taking as arguments:
+
+             - **model** ([`PreTrainedModel`]) -- An instance of the model on which to load the TensorFlow checkpoint.
+             - **config** ([`PreTrainedConfig`]) -- An instance of the configuration associated to the model.
+             - **path** (`str`) -- A path to the TensorFlow checkpoint.
+
+         - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
+           classes of the same architecture adding modules on top of the base model.
+         - **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization.
+         - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
+           models, `pixel_values` for vision models and `input_values` for speech models).
+     """
+
+     config_class = MambaConfig  # TODO: Build on top of MambaConfig?
+     # base_model_prefix = "backbone"
+     base_model_prefix = "mamba"
+     main_input_name = "input_ids"
+     model_tags = None
+
+     _auto_class = None
+     _no_split_modules = ["MambaBlock"]
+     _skip_keys_device_placement = None
+     _keep_in_fp32_modules = None
+
+     # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
+     # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
+     _keys_to_ignore_on_load_missing = None
+     # a list of `re` patterns of `state_dict` keys that should be removed from the list of
+     # unexpected keys we find (keys inside the checkpoint but not the model) and avoid unnecessary
+     # warnings.
+     _keys_to_ignore_on_load_unexpected = None
+     # a list of `state_dict` keys to ignore when saving the model (useful for keys that aren't
+     # trained, but which are either deterministic or tied variables)
+     _keys_to_ignore_on_save = None
+     # a list of `state_dict` keys that are potentially tied to another key in the state_dict.
+     _tied_weights_keys = None
+
+     is_parallelizable = False
+     supports_gradient_checkpointing = True
+
+     # Flash Attention 2 support
+     _supports_flash_attn_2 = False
+
+     # SDPA support
+     _supports_sdpa = False
+
+     # Has support for a `Cache` instance as `past_key_values`
+     _supports_cache_class = False
+
+     def __init__(self, *inputs, **kwargs):
+         super().__init__(*inputs, **kwargs)
+
+     # https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/mamba_ssm/models/mixer_seq_simple.py#L54
+     def _init_weights(
          self,
+         module,
+         initializer_range=0.02,  # Now only used for embedding layer.
+         rescale_prenorm_residual=True,
+         n_residuals_per_layer=1,  # Change to 2 if we have MLP
      ):
+         if isinstance(module, nn.Linear):
+             if module.bias is not None:
+                 if not getattr(module.bias, "_no_reinit", False):
+                     nn.init.zeros_(module.bias)
+
+         elif isinstance(module, nn.Embedding):
+             nn.init.normal_(module.weight, std=initializer_range)
+
+         if rescale_prenorm_residual:
+             # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+             #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+             #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+             #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
+             #
+             # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+             for name, p in module.named_parameters():
+                 if name in ["out_proj.weight", "fc2.weight"]:
+                     # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                     # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
+                     # We need to reinit p since this code could be called multiple times
+                     # Having just p *= scale would repeatedly scale it down
+                     nn.init.kaiming_uniform_(p, a=math.sqrt(5))
+                     with torch.no_grad():
+                         p /= math.sqrt(n_residuals_per_layer * self.config.n_layer)
+
+     # def _set_gradient_checkpointing(self, module, value=False):
+     #     if isinstance(module, GPT2Model):
+     #         module.gradient_checkpointing = value
+
+
+ class MambaModel(MambaPretrainedModel):
370
+ def __init__(
371
+ self, config: MambaConfig = MambaConfig(), **kwargs
372
+ ) -> None:
373
+ """Full Mamba model.
374
+ Mamba model decoder consisting of *config.n_layer* layers. Each layer is a [`MambaBlock`]
375
+ Args:
376
+ config: MambaConfig
377
+ """
378
  super().__init__(
379
  config,
380
  **kwargs,
381
  )
382
 
383
+ # self.embedding = nn.Embedding(
384
+ # num_embeddings=config.vocab_size,
385
+ # embedding_dim=config.d_model,
386
+ # )
387
+
388
+
389
+ self.embedding = nn.Embedding(
390
+ num_embeddings=config.vocab_size,
391
+ embedding_dim=config.d_model,
392
  )
393
 
394
+ self.layers = nn.ModuleList(
395
+ [ResidualBlock(config) for _ in range(self.config.n_layer)]
 
 
 
 
 
 
 
 
 
 
 
396
  )
397
+ # self.layers = nn.ModuleList([MambaBlock(config) for _ in range(config.n_layer)])
398
+ # # self.norm_f = RMSNorm(d_model=embedding_dim)
399
+ self.norm_f = RMSNorm(config.d_model)
400
+
401
+ # self.gradient_checkpointing = False
402
+ # # self.post_init()
403
+
404
+ # Initialize weights and apply final processing
405
+ self.post_init()
406
+
407
+ # def _init_weights(self, module):
408
+ # std = 0.02
409
+
410
+ # if isinstance(module, (nn.Linear, nn.Conv1d)):
411
+ # module.weight.data.normal_(mean=0.0, std=std)
412
+
413
+ # if module.bias is not None:
414
+ # module.bias.data.zero_()
415
+
416
+ # elif isinstance(module, nn.Embedding):
417
+ # module.weight.data.normal_(mean=0.0, std=std)
418
+
419
+ # if module.padding_idx is not None:
420
+ # module.weight.data[module.padding_idx].zero_()
421
+
422
+ # Inspired by:
423
+ # - https://huggingface.co/Q-bert/Mamba-130M/blob/f0d00db98acaa62b1ee4304cd11643e69aa62a71/modeling_mamba.py#L198
424
+ # - https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/mamba_ssm/models/mixer_seq_simple.py#L86
425
+ # class MambaModel(MambaPretrainedModel):
426
+ # def __init__(
427
+ # self,
428
+ # config: MambaConfig = MambaConfig(),
429
+ # **kwargs,
430
+ # ) -> None:
431
+ # super().__init__(
432
+ # config,
433
+ # **kwargs,
434
+ # )
435
+
436
+ # self.embedding = nn.Embedding(
437
+ # num_embeddings=config.vocab_size,
438
+ # embedding_dim=config.d_model,
439
+ # )
440
+
441
+ # # # self.layers = nn.ModuleList(
442
+ # # # [ResidualBlock(args=model_args) for _ in range(model_args.n_layer)]
443
+ # # # )
444
+ # self.layers = nn.ModuleList([MambaBlock(config) for _ in range(config.n_layer)])
445
+ # # # self.norm_f = RMSNorm(d_model=embedding_dim)
446
+ # self.norm_f = RMSNorm(config.d_model)
447
+
448
+ # # self.gradient_checkpointing = False
449
+ # # # self.post_init()
450
+
451
+ # def get_input_embeddings(self):
452
+ # return self.embed_out
453
 
454
+ # def set_input_embeddings(self, value):
455
+ # self.embed_out = value
456
+
457
+ # def forward(
458
+ # self,
459
+ # input_ids: torch.LongTensor = None,
460
+ # output_hidden_states=False,
461
+ # return_dict: Optional[bool] = None,
462
+ # **kwargs,
463
+ # # ) -> BaseModelOutput:
464
+ # ) -> Union[Tuple, BaseModelOutputWithPast]:
465
+ # batch_size = input_ids.shape[0]
466
+ # hidden_size = self.config.hidden_size
467
+ # hidden_states: Tuple[Tensor[(batch_size, sequence_length, hidden_size)]] = ()
468
+ # sequence_length = input_ids.shape[1]
469
+ # output_hidden_states = output_hidden_states or self.config.output_hidden_states
470
+
471
+ # last_hidden_state = self.embed_out(input_ids)
472
+ # assert last_hidden_state.shape == (
473
+ # batch_size,
474
+ # sequence_length,
475
+ # hidden_size,
476
+ # ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
477
+ # hidden_states += (last_hidden_state,)
478
+
479
+ # for layer in self.layers:
480
+ # last_hidden_state = layer(last_hidden_state)
481
+ # assert last_hidden_state.shape == (
482
+ # batch_size,
483
+ # sequence_length,
484
+ # hidden_size,
485
+ # ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
486
+ # hidden_states += (last_hidden_state,)
487
+
488
+ # last_hidden_state = self.norm_f(last_hidden_state)
489
+ # assert last_hidden_state.shape == (
490
+ # batch_size,
491
+ # sequence_length,
492
+ # hidden_size,
493
+ # ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
494
+ # hidden_states += (last_hidden_state,)
495
+
496
+ # assert (
497
+ # len(hidden_states) == self.config.n_layer + 2
498
+ # ), f"{len(hidden_states)} != {self.config.n_layer + 2}"
499
+
500
+ # # return BaseModelOutput(
501
+ # return BaseModelOutputWithPast(
502
+ # hidden_states=hidden_states if output_hidden_states else None,
503
+ # last_hidden_state=last_hidden_state,
504
+ # )
505
+
506
+
+ # Influences:
+ # - https://huggingface.co/Q-bert/Mamba-130M/blob/f0d00db98acaa62b1ee4304cd11643e69aa62a71/modeling_mamba.py#L238
+ # - https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/mamba_ssm/models/mixer_seq_simple.py#L176
+ # class MambaModelForCausalLM(MambaModel, GenerationMixin):
+ # class MambaModelForCausalLM(PreTrainedModel, GenerationMixin):
+ # class MambaLMHeadModel(MambaPretrainedModel, GenerationMixin):
+ class MambaLMHeadModel(MambaPretrainedModel):
+     # _tied_weights_keys = ["lm_head.weight",
+
+     def __init__(
          self,
+         config: MambaConfig = MambaConfig(),
          **kwargs,
+     ) -> None:
+         super().__init__(
+             config,
+             **kwargs,
+         )
+
+         self.backbone = MambaModel(
+             config=self.config,
+         )
+
+         self.lm_head = nn.Linear(
+             in_features=self.config.hidden_size,
+             out_features=self.config.vocab_size,
+             bias=False,
          )
+
+         # # self.head.weight = self.backbone.embedding.weight  # TODO: there's some logic in GenerationMix that does this
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     # # def forward(
+     # #     self, input_ids, output_hidden_states=False, **kwargs
+     # # ) -> CausalLMOutput:
+     # #     batch_size = input_ids.shape[0]
+     # #     sequence_length = input_ids.shape[1]
+     # #     vocab_size = self.config.vocab_size
+     # #     output_hidden_states = output_hidden_states or self.config.output_hidden_states
+
+     # #     outputs = self.backbone(
+     # #         input_ids=input_ids,
+     # #         output_hidden_states=output_hidden_states,
+     # #     )
+
+     # #     last_hidden_state = outputs.last_hidden_state
+
+     # #     logits: torch.FloatTensor[batch_size, sequence_length, vocab_size] = (
+     # #         self.lm_head(
+     # #             last_hidden_state,
+     # #         )
+     # #     )
+
+     # #     return CausalLMOutput(
+     # #         hidden_states=outputs.hidden_states if output_hidden_states else None,
+     # #         logits=logits,
+     # #     )
+
+     # # def prepare_inputs_for_generation(
+     # #     self, input_ids, attention_mask=None, **model_kwargs
+     # # ):
+     # #     return {
+     # #         "input_ids": input_ids,
+     # #     }
+
+
+ # class MultimodalMambaModelForCausalLMWithValueHead(PreTrainedModelWrapper):
+ #     lm_head_namings: Tuple[str, str] = ("lm_head", "embed_out")
+ #     transformers_parent_class: transformers.PreTrainedModel = transformers.AutoModelForCausalLM
+
+ #     # def __init__(
+ #     #     self,
+ #     #     config: MultimodalMambaConfig = MultimodalMambaConfig(),
+ #     #     **kwargs,
+ #     # ) -> None:
+ #     #     super().__init__(
+ #     #         config,
+ #     #         **kwargs,
+ #     #     )
+
+ #     #     self.model = MultimodalMambaModelForCausalLM(
+ #     #         config=config,
+ #     #     )
+
+ #     #     self.value_head = nn.Linear(
+ #     #         in_features=config.embedding_dim,
+ #     #         out_features=1,
+ #     #         bias=False,
+ #     #     )
+
+ #     # def forward(
+ #     #     self, input_ids, output_hidden_states=False, **kwargs
+ #     # ) -> CausalLMOutput:
+ #     #     outputs = self.model(
+ #     #         input_ids=input_ids,
+ #     #         output_hidden_states=output_hidden_states,
+ #     #     )
+
+ #     #     last_hidden_state = outputs.last_hidden_state
+
+ #     #     value: torch.FloatTensor[batch_size, sequence_length, 1] = self.value_head(
+ #     #         last_hidden_state,
+ #     #     )
+
+ #     #     return CausalLMOutput(
+ #     #         hidden_states=outputs.hidden_states if output_hidden_states else None,
+ #     #         logits=outputs.logits,
+ #     #         value=value,
+ #     #     )
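For reference, the sequential recurrence implemented by MambaBlock.selective_scan above can be exercised on its own. The following sketch is illustrative only (it is not part of the committed file) and uses random tensors with the shapes given in the docstring: u and delta are (b, l, d_in), A is (d_in, n), B and C are (b, l, n), D is (d_in,):

    import torch
    import torch.nn.functional as F
    from einops import einsum

    b, l, d_in, n = 2, 5, 4, 3
    u = torch.randn(b, l, d_in)
    delta = F.softplus(torch.randn(b, l, d_in))  # positive step sizes
    A = -torch.rand(d_in, n)                     # negative entries keep the recurrence stable
    B = torch.randn(b, l, n)
    C = torch.randn(b, l, n)
    D = torch.ones(d_in)

    # Discretize (ZOH for A, simplified Euler for B), as in selective_scan above.
    deltaA = torch.exp(einsum(delta, A, "b l d_in, d_in n -> b l d_in n"))
    deltaB_u = einsum(delta, B, u, "b l d_in, b l n, b l d_in -> b l d_in n")

    # Sequential scan: x_t = deltaA_t * x_{t-1} + deltaB_u_t, y_t = C_t x_t, plus the D * u skip term.
    x = torch.zeros(b, d_in, n)
    ys = []
    for i in range(l):
        x = deltaA[:, i] * x + deltaB_u[:, i]
        ys.append(einsum(x, C[:, i, :], "b d_in n, b n -> b d_in"))
    y = torch.stack(ys, dim=1) + u * D

    print(y.shape)  # torch.Size([2, 5, 4])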