Upload model

- config.json: +6 -1
- model.safetensors: +2 -2
- modeling_mamba.py: +596 -61
config.json
CHANGED

@@ -1,6 +1,10 @@
 {
+  "architectures": [
+    "MambaLMHeadModel"
+  ],
   "auto_map": {
-    "AutoConfig": "configuration_mamba.MambaConfig"
+    "AutoConfig": "configuration_mamba.MambaConfig",
+    "AutoModelForCausalLM": "modeling_mamba.MambaLMHeadModel"
   },
   "bias": false,
   "conv_bias": true,
@@ -14,6 +18,7 @@
   "model_type": "mamba",
   "n_layer": 24,
   "pad_vocab_size_multiple": 8,
+  "torch_dtype": "float32",
   "transformers_version": "4.37.2",
   "vocab_size": 50280
 }
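With "architectures" and the new "AutoModelForCausalLM" entry in "auto_map", the checkpoint should be loadable through the transformers Auto classes once remote code is enabled. A minimal, untested sketch; the repository id below is a placeholder and the tokenizer choice is an assumption (the 50,280-entry vocabulary matches the padded GPT-NeoX tokenizer used by the reference Mamba checkpoints), neither is part of this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "<user>/<repo>"  # placeholder: the repository this commit belongs to

# trust_remote_code=True lets transformers import configuration_mamba.py and
# modeling_mamba.py from the repository, as registered in auto_map above.
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

# Assumption: GPT-NeoX tokenizer, as used by the reference Mamba checkpoints.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")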
model.safetensors
CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:1bd3ca62665de4bfabff9d443f87a11090a10e505c0ccb56e6f9ca495b6e05bd
+size 671027808
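A quick sanity check on the new LFS pointer (plain arithmetic, not stated anywhere in the commit): with weights stored as float32, as "torch_dtype" in config.json suggests, the file size corresponds to roughly 168M parameters.

# Rough parameter count implied by the safetensors size, assuming every tensor is
# stored as float32 (4 bytes per value) and no tensors are shared.
size_bytes = 671_027_808
print(f"{size_bytes / 4 / 1e6:.1f}M parameters")  # ≈ 167.8M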
modeling_mamba.py
CHANGED

@@ -1,82 +1,617 @@
-
 
-from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
 import torch
 from transformers import GenerationMixin, PreTrainedModel
-from transformers.
 
 from .configuration_mamba import MambaConfig
 
-class MambaModel(PreTrainedModel):
-    config_class = MambaConfig
 
-
         self,
-
-
-
-
-        **kwargs,
     ):
         super().__init__(
             config,
             **kwargs,
         )
 
-        self.
-
-
-
-
         )
 
-
-        self
-        input_ids,
-        position_ids=None,
-        inference_params=None,
-        num_last_tokens=0,
-        **kwargs,
-    ):
-        return self.model.forward(
-            input_ids,
-            position_ids,
-            inference_params,
-            num_last_tokens
         )
 
-
-
         self,
-
-        max_length: int = 2048,
-        top_k: int = 1,
-        top_p: float = 0.0,
-        temperature: float = 1.0,
-        return_dict_in_generate: bool = False,
-        output_scores: bool = False,
-        repetition_penalty: float = 1.0,
-        eos_token_id: Optional[int] = None,
-        teacher_outputs: Optional[torch.Tensor] = None,
-        vocab_size: Optional[int] = None,
-        cg: bool = False,
-        enable_timing: bool = False,
-        streamer: Optional[TextStreamer] = None,
         **kwargs,
-    ):
-
-
-
-
-
-
-
-
-
-
-
-
-
-            enable_timing=enable_timing,
-            streamer=streamer,
         )
+import json
+import math
+import os
+from collections import namedtuple
+from dataclasses import dataclass
+from functools import partial
+from typing import Dict, Optional, Tuple, Union
 
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import transformers
+from einops import einsum, rearrange, repeat
+from torch import FloatTensor, Tensor, nn
 from transformers import GenerationMixin, PreTrainedModel
+from transformers.modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPast,
+    CausalLMOutput,
+    ImageClassifierOutput,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutput,
+)
+from trl import PreTrainedModelWrapper
 
 from .configuration_mamba import MambaConfig
 
 
+# class SwiGLU(nn.Module):
+#     def forward(self, x, W, V, b, c, beta):
+#         return F.silu(x * W + b) * (x * V + c)
+
+
+# Inspired by:
+# - https://huggingface.co/Q-bert/Mamba-130M/blob/f0d00db98acaa62b1ee4304cd11643e69aa62a71/modeling_mamba.py#L31
+# - https://github.com/johnma2006/mamba-minimal/blob/03de542a36d873f6e6c4057ad687278cc6ae944d/model.py#L177
+# - https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/mamba_ssm/modules/mamba_simple.py#L31
+class MambaBlock(nn.Module):
+    def __init__(self, config: MambaConfig):
+        """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1].
+        Furthermore, in section E.2.2 of the paper, the authors describe the Mamba block as:
+        "[T]he Mamba block is simply the standard SwiGLU block with an extra conv → SSM path added."
+        """
+        super().__init__()
+
+        self.config = config
+
+        self.in_proj = nn.Linear(config.d_model, config.d_inner * 2, bias=config.bias)
+
+        self.conv1d = nn.Conv1d(
+            in_channels=config.d_inner,
+            out_channels=config.d_inner,
+            bias=config.conv_bias,
+            kernel_size=config.d_conv,
+            groups=config.d_inner,
+            padding=config.d_conv - 1,
+        )
+
+        # x_proj takes in `x` and outputs the input-specific Δ, B, C
+        self.x_proj = nn.Linear(
+            config.d_inner, config.dt_rank + config.d_state * 2, bias=False
+        )
+
+        # dt_proj projects Δ from dt_rank to d_in
+        self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True)
+
+        A = repeat(torch.arange(1, config.d_state + 1), "n -> d n", d=config.d_inner)
+        self.A_log = nn.Parameter(torch.log(A))
+        self.D = nn.Parameter(torch.ones(config.d_inner))
+
+        self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias)
+        # self.norm = RMSNorm(config.d_model)
+
+
+
+    def forward(self, x):
+        """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1].
+
+        Args:
+            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)
+
+        Returns:
+            output: shape (b, l, d)
+
+        Official Implementation:
+            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
+            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
+
+        """
+        (b, l, d) = x.shape
+        # x_copy = x  # There was a separate class for residual, I deleted that part and added it here.
+        # x = self.norm(x)
+        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
+        (x, res) = x_and_res.split(
+            split_size=[self.config.d_inner, self.config.d_inner], dim=-1
+        )
+
+        x = rearrange(x, "b l d_in -> b d_in l")
+        x = self.conv1d(x)[:, :, :l]
+        x = rearrange(x, "b d_in l -> b l d_in")
+
+        x = F.silu(x)
+
+        y = self.ssm(x)
+
+        y = y * F.silu(res)  # SwiGLU: Swish_β(xW + b) ⊗ (xV + c) => torch.kron(F.silu(xW + b), xV + c) => torch.kron(F.silu(res), y)
+
+        output = self.out_proj(y)  # output = self.out_proj(y) + x_copy
+
+        # "the Mamba block is simply the standard SwiGLU block with an extra 𝖼𝗈𝗇𝗏 → 𝖲𝖲𝖬 path added"
+
+        return output
+
+    def ssm(self, x):
+        """Runs the SSM. See:
+            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
+            - run_SSM(A, B, C, u) in The Annotated S4 [2]
+
+        Args:
+            x: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
+
+        Returns:
+            output: shape (b, l, d_in)
+
+        Official Implementation:
+            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
+
+        """
+        (d_in, n) = self.A_log.shape
+
+        # Compute ∆ A B C D, the state space parameters.
+        #     A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
+        #     ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
+        #     and is why Mamba is called **selective** state spaces)
+
+        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
+        D = self.D.float()
+
+        x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)
+
+        (delta, B, C) = x_dbl.split(
+            split_size=[self.config.dt_rank, n, n], dim=-1
+        )  # delta: (b, l, dt_rank). B, C: (b, l, n)
+        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)
+
+        y = self.selective_scan(
+            x, delta, A, B, C, D
+        )  # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]
+
+        return y
+
+    def selective_scan(self, u, delta, A, B, C, D):
+        """Does selective scan algorithm. See:
+            - Section 2 State Space Models in the Mamba paper [1]
+            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
+            - run_SSM(A, B, C, u) in The Annotated S4 [2]
+
+        This is the classic discrete state space formula:
+            x(t + 1) = Ax(t) + Bu(t)
+            y(t)     = Cx(t) + Du(t)
+        except B and C (and the step size delta, which is used for discretization) are dependent on the input x(t).
+
+        Args:
+            u: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
+            delta: shape (b, l, d_in)
+            A: shape (d_in, n)
+            B: shape (b, l, n)
+            C: shape (b, l, n)
+            D: shape (d_in,)
+
+        Returns:
+            output: shape (b, l, d_in)
+
+        Official Implementation:
+            selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
+            Note: I refactored some parts out of `selective_scan_ref` out, so the functionality doesn't match exactly.
+
+        """
+        (b, l, d_in) = u.shape
+        n = A.shape[1]
+
+        # Discretize continuous parameters (A, B)
+        # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1])
+        # - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
+        #   "A is the more important term and the performance doesn't change much with the simplification on B"
+        deltaA = torch.exp(einsum(delta, A, "b l d_in, d_in n -> b l d_in n"))
+        deltaB_u = einsum(delta, B, u, "b l d_in, b l n, b l d_in -> b l d_in n")
+
+        # Perform selective scan (see scan_SSM() in The Annotated S4 [2])
+        # Note that the below is sequential, while the official implementation does a much faster parallel scan that
+        # is additionally hardware-aware (like FlashAttention).
+        x = torch.zeros((b, d_in, n), device=deltaA.device)
+        ys = []
+
+        for i in range(l):
+            x = deltaA[:, i] * x + deltaB_u[:, i]
+            y = einsum(x, C[:, i, :], "b d_in n, b n -> b d_in")
+            ys.append(y)
+
+        y = torch.stack(ys, dim=1)  # shape (b, l, d_in)
+
+        y = y + u * D
+
+        return y
+
+
+# Inspired by:
+# - https://huggingface.co/Q-bert/Mamba-130M/blob/f0d00db98acaa62b1ee4304cd11643e69aa62a71/modeling_mamba.py#L19
+# - https://github.com/johnma2006/mamba-minimal/blob/03de542a36d873f6e6c4057ad687278cc6ae944d/model.py#L328
+# - https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/mamba_ssm/ops/triton/layernorm.py#L481
+class RMSNorm(nn.Module):
+    def __init__(self, d_model: int, eps: float = 1e-5):
+        super().__init__()
+
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(d_model))
+
+    def forward(self, x):
+        output = (
+            x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
+        )
+
+        return output
+
+
+class ResidualBlock(
+    nn.Module
+):  # Copied and modified from https://github.com/johnma2006/mamba-minimal/blob/03de542a36d873f6e6c4057ad687278cc6ae944d/model.py#L143
+    def __init__(self, config: MambaConfig):
+        """Simple block wrapping Mamba block with normalization and residual connection."""
+        super().__init__()
+
+        # self.args = args
+        self.mixer = MambaBlock(config)
+        self.norm = RMSNorm(config.d_model)
+
+    def forward(self, x):
+        """
+        Args:
+            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)
+
+        Returns:
+            output: shape (b, l, d)
+
+        Official Implementation:
+            Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297
+
+            Note: the official repo chains residual blocks that look like
+                [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
+                where the first Add is a no-op. This is purely for performance reasons as this
+                allows them to fuse the Add->Norm.
+
+            We instead implement our blocks as the more familiar, simpler, and numerically equivalent
+                [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ....
+
+        """
+        output = self.mixer(self.norm(x)) + x
+
+        return output
+
+# Inspired by:
+# - https://huggingface.co/Q-bert/Mamba-130M/blob/f0d00db98acaa62b1ee4304cd11643e69aa62a71/modeling_mamba.py#L181
+# class MambaPretrainedModel(PreTrainedModel, nn.Module):
+class MambaPretrainedModel(PreTrainedModel):
+    r"""
+    Base class for all models.
+
+    [`PreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
+    downloading and saving models as well as a few methods common to all models to:
+
+        - resize the input embeddings,
+        - prune heads in the self-attention heads.
+
+    Class attributes (overridden by derived classes):
+
+        - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
+          for this model architecture.
+        - **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
+          taking as arguments:
+
+            - **model** ([`PreTrainedModel`]) -- An instance of the model on which to load the TensorFlow checkpoint.
+            - **config** ([`PreTrainedConfig`]) -- An instance of the configuration associated to the model.
+            - **path** (`str`) -- A path to the TensorFlow checkpoint.
+
+        - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
+          classes of the same architecture adding modules on top of the base model.
+        - **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization.
+        - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
+          models, `pixel_values` for vision models and `input_values` for speech models).
+    """
+
+    config_class = MambaConfig  # TODO: Build on top of MambaConfig?
+    # base_model_prefix = "backbone"
+    base_model_prefix = "mamba"
+    main_input_name = "input_ids"
+    model_tags = None
+
+    _auto_class = None
+    _no_split_modules = ["MambaBlock"]
+    _skip_keys_device_placement = None
+    _keep_in_fp32_modules = None
+
+    # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
+    # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
+    _keys_to_ignore_on_load_missing = None
+    # a list of `re` patterns of `state_dict` keys that should be removed from the list of
+    # unexpected keys we find (keys inside the checkpoint but not the model) and avoid unnecessary
+    # warnings.
+    _keys_to_ignore_on_load_unexpected = None
+    # a list of `state_dict` keys to ignore when saving the model (useful for keys that aren't
+    # trained, but which are either deterministic or tied variables)
+    _keys_to_ignore_on_save = None
+    # a list of `state_dict` keys that are potentially tied to another key in the state_dict.
+    _tied_weights_keys = None
+
+    is_parallelizable = False
+    supports_gradient_checkpointing = True
+
+    # Flash Attention 2 support
+    _supports_flash_attn_2 = False
+
+    # SDPA support
+    _supports_sdpa = False
+
+    # Has support for a `Cache` instance as `past_key_values`
+    _supports_cache_class = False
+
+    def __init__(self, *inputs, **kwargs):
+        super().__init__(*inputs, **kwargs)
+
+    # https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/mamba_ssm/models/mixer_seq_simple.py#L54
+    def _init_weights(
         self,
+        module,
+        initializer_range=0.02,  # Now only used for embedding layer.
+        rescale_prenorm_residual=True,
+        n_residuals_per_layer=1,  # Change to 2 if we have MLP
     ):
+        if isinstance(module, nn.Linear):
+            if module.bias is not None:
+                if not getattr(module.bias, "_no_reinit", False):
+                    nn.init.zeros_(module.bias)
+
+        elif isinstance(module, nn.Embedding):
+            nn.init.normal_(module.weight, std=initializer_range)
+
+        if rescale_prenorm_residual:
+            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
+            #
+            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+            for name, p in module.named_parameters():
+                if name in ["out_proj.weight", "fc2.weight"]:
+                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
+                    # We need to reinit p since this code could be called multiple times
+                    # Having just p *= scale would repeatedly scale it down
+                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
+                    with torch.no_grad():
+                        p /= math.sqrt(n_residuals_per_layer * self.config.n_layer)
+
+    # def _set_gradient_checkpointing(self, module, value=False):
+    #     if isinstance(module, GPT2Model):
+    #         module.gradient_checkpointing = value
+
+
+class MambaModel(MambaPretrainedModel):
+    def __init__(
+        self, config: MambaConfig = MambaConfig(), **kwargs
+    ) -> None:
+        """Full Mamba model.
+        Mamba model decoder consisting of *config.n_layer* layers. Each layer is a [`MambaBlock`]
+        Args:
+            config: MambaConfig
+        """
         super().__init__(
             config,
             **kwargs,
         )
 
+        # self.embedding = nn.Embedding(
+        #     num_embeddings=config.vocab_size,
+        #     embedding_dim=config.d_model,
+        # )
+
+
+        self.embedding = nn.Embedding(
+            num_embeddings=config.vocab_size,
+            embedding_dim=config.d_model,
         )
 
+        self.layers = nn.ModuleList(
+            [ResidualBlock(config) for _ in range(self.config.n_layer)]
         )
+        # self.layers = nn.ModuleList([MambaBlock(config) for _ in range(config.n_layer)])
+        # # self.norm_f = RMSNorm(d_model=embedding_dim)
+        self.norm_f = RMSNorm(config.d_model)
+
+        # self.gradient_checkpointing = False
+        # # self.post_init()
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    # def _init_weights(self, module):
+    #     std = 0.02
+
+    #     if isinstance(module, (nn.Linear, nn.Conv1d)):
+    #         module.weight.data.normal_(mean=0.0, std=std)
+
+    #         if module.bias is not None:
+    #             module.bias.data.zero_()
+
+    #     elif isinstance(module, nn.Embedding):
+    #         module.weight.data.normal_(mean=0.0, std=std)
+
+    #         if module.padding_idx is not None:
+    #             module.weight.data[module.padding_idx].zero_()
+
+# Inspired by:
+# - https://huggingface.co/Q-bert/Mamba-130M/blob/f0d00db98acaa62b1ee4304cd11643e69aa62a71/modeling_mamba.py#L198
+# - https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/mamba_ssm/models/mixer_seq_simple.py#L86
+# class MambaModel(MambaPretrainedModel):
+#     def __init__(
+#         self,
+#         config: MambaConfig = MambaConfig(),
+#         **kwargs,
+#     ) -> None:
+#         super().__init__(
+#             config,
+#             **kwargs,
+#         )
+
+#         self.embedding = nn.Embedding(
+#             num_embeddings=config.vocab_size,
+#             embedding_dim=config.d_model,
+#         )
+
+#         # # self.layers = nn.ModuleList(
+#         # #     [ResidualBlock(args=model_args) for _ in range(model_args.n_layer)]
+#         # # )
+#         self.layers = nn.ModuleList([MambaBlock(config) for _ in range(config.n_layer)])
+#         # # self.norm_f = RMSNorm(d_model=embedding_dim)
+#         self.norm_f = RMSNorm(config.d_model)
+
+#         # self.gradient_checkpointing = False
+#         # # self.post_init()
+
+#     def get_input_embeddings(self):
+#         return self.embed_out
 
+#     def set_input_embeddings(self, value):
+#         self.embed_out = value
+
+#     def forward(
+#         self,
+#         input_ids: torch.LongTensor = None,
+#         output_hidden_states=False,
+#         return_dict: Optional[bool] = None,
+#         **kwargs,
+#     # ) -> BaseModelOutput:
+#     ) -> Union[Tuple, BaseModelOutputWithPast]:
+#         batch_size = input_ids.shape[0]
+#         hidden_size = self.config.hidden_size
+#         hidden_states: Tuple[Tensor[(batch_size, sequence_length, hidden_size)]] = ()
+#         sequence_length = input_ids.shape[1]
+#         output_hidden_states = output_hidden_states or self.config.output_hidden_states
+
+#         last_hidden_state = self.embed_out(input_ids)
+#         assert last_hidden_state.shape == (
+#             batch_size,
+#             sequence_length,
+#             hidden_size,
+#         ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
+#         hidden_states += (last_hidden_state,)
+
+#         for layer in self.layers:
+#             last_hidden_state = layer(last_hidden_state)
+#             assert last_hidden_state.shape == (
+#                 batch_size,
+#                 sequence_length,
+#                 hidden_size,
+#             ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
+#             hidden_states += (last_hidden_state,)
+
+#         last_hidden_state = self.norm_f(last_hidden_state)
+#         assert last_hidden_state.shape == (
+#             batch_size,
+#             sequence_length,
+#             hidden_size,
+#         ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
+#         hidden_states += (last_hidden_state,)
+
+#         assert (
+#             len(hidden_states) == self.config.n_layer + 2
+#         ), f"{len(hidden_states)} != {self.config.n_layer + 2}"
+
+#         # return BaseModelOutput(
+#         return BaseModelOutputWithPast(
+#             hidden_states=hidden_states if output_hidden_states else None,
+#             last_hidden_state=last_hidden_state,
+#         )
+
+
+# Influences:
+# - https://huggingface.co/Q-bert/Mamba-130M/blob/f0d00db98acaa62b1ee4304cd11643e69aa62a71/modeling_mamba.py#L238
+# - https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/mamba_ssm/models/mixer_seq_simple.py#L176
+# class MambaModelForCausalLM(MambaModel, GenerationMixin):
+# class MambaModelForCausalLM(PreTrainedModel, GenerationMixin):
+# class MambaLMHeadModel(MambaPretrainedModel, GenerationMixin):
+class MambaLMHeadModel(MambaPretrainedModel):
+    # _tied_weights_keys = ["lm_head.weight",
+
+    def __init__(
         self,
+        config: MambaConfig = MambaConfig(),
         **kwargs,
+    ) -> None:
+        super().__init__(
+            config,
+            **kwargs,
+        )
+
+        self.backbone = MambaModel(
+            config=self.config,
+        )
+
+        self.lm_head = nn.Linear(
+            in_features=self.config.hidden_size,
+            out_features=self.config.vocab_size,
+            bias=False,
         )
+
+        # # self.head.weight = self.backbone.embedding.weight  # TODO: there's some logic in GenerationMix that does this
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    # # def forward(
+    # #     self, input_ids, output_hidden_states=False, **kwargs
+    # # ) -> CausalLMOutput:
+    # #     batch_size = input_ids.shape[0]
+    # #     sequence_length = input_ids.shape[1]
+    # #     vocab_size = self.config.vocab_size
+    # #     output_hidden_states = output_hidden_states or self.config.output_hidden_states
+
+    # #     outputs = self.backbone(
+    # #         input_ids=input_ids,
+    # #         output_hidden_states=output_hidden_states,
+    # #     )
+
+    # #     last_hidden_state = outputs.last_hidden_state
+
+    # #     logits: torch.FloatTensor[batch_size, sequence_length, vocab_size] = (
+    # #         self.lm_head(
+    # #             last_hidden_state,
+    # #         )
+    # #     )
+
+    # #     return CausalLMOutput(
+    # #         hidden_states=outputs.hidden_states if output_hidden_states else None,
+    # #         logits=logits,
+    # #     )
+
+    # #     def prepare_inputs_for_generation(
+    # #         self, input_ids, attention_mask=None, **model_kwargs
+    # #     ):
+    # #         return {
+    # #             "input_ids": input_ids,
+    # #         }
+
+
+# class MultimodalMambaModelForCausalLMWithValueHead(PreTrainedModelWrapper):
+#     lm_head_namings: Tuple[str, str] = ("lm_head", "embed_out")
+#     transformers_parent_class: transformers.PreTrainedModel = transformers.AutoModelForCausalLM
+
+#     # def __init__(
+#     #     self,
+#     #     config: MultimodalMambaConfig = MultimodalMambaConfig(),
+#     #     **kwargs,
+#     # ) -> None:
+#     #     super().__init__(
+#     #         config,
+#     #         **kwargs,
+#     #     )
+
+#     #     self.model = MultimodalMambaModelForCausalLM(
+#     #         config=config,
+#     #     )
+
+#     #     self.value_head = nn.Linear(
+#     #         in_features=config.embedding_dim,
+#     #         out_features=1,
+#     #         bias=False,
+#     #     )
+
+#     # def forward(
+#     #     self, input_ids, output_hidden_states=False, **kwargs
+#     # ) -> CausalLMOutput:
+#     #     outputs = self.model(
+#     #         input_ids=input_ids,
+#     #         output_hidden_states=output_hidden_states,
+#     #     )
+
+#     #     last_hidden_state = outputs.last_hidden_state
+
+#     #     value: torch.FloatTensor[batch_size, sequence_length, 1] = self.value_head(
+#     #         last_hidden_state,
+#     #     )
+
+#     #     return CausalLMOutput(
+#     #         hidden_states=outputs.hidden_states if output_hidden_states else None,
+#     #         logits=outputs.logits,
+#     #         value=value,
+#     #     )