|
from typing import Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from einops import einsum, rearrange, repeat |
|
from torch.nn import CrossEntropyLoss |
|
from transformers.modeling_outputs import ( |
|
BaseModelOutputWithPast, |
|
CausalLMOutputWithPast, |
|
SequenceClassifierOutputWithPast, |
|
) |
|
from transformers.modeling_utils import PreTrainedModel |
|
|
|
from .configuration_mamba import MambaConfig |
|
|
|
|
|
class MambaRMSNorm(nn.Module): |
|
def __init__(self, d_model: int, eps: float = 1e-5): |
|
super().__init__() |
|
self.eps = eps |
|
self.weight = nn.Parameter(torch.ones(d_model)) |
|
|
|
def forward(self, x): |
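        # RMSNorm: scale x by the reciprocal root-mean-square over the last dimension
        # (no mean subtraction and no bias, unlike LayerNorm), then apply the learned weight.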
|
output = ( |
|
x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight |
|
) |
|
return output |
|
|
|
|
|
class Mamba(nn.Module): |
|
def __init__(self, config: MambaConfig): |
|
"""A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1].""" |
|
super().__init__() |
|
self.config = config |
|
|
|
self.in_proj = nn.Linear(config.d_model, config.d_inner * 2, bias=config.bias) |
|
|
|
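        # Depthwise (groups == in_channels) 1D convolution over the sequence dimension;
        # padding = d_conv - 1 together with the truncation to length l in forward() keeps it causal.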
self.conv1d = nn.Conv1d( |
|
in_channels=config.d_inner, |
|
out_channels=config.d_inner, |
|
bias=config.conv_bias, |
|
kernel_size=config.d_conv, |
|
groups=config.d_inner, |
|
padding=config.d_conv - 1, |
|
) |
|
|
|
|
|
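        # x_proj produces the input-dependent SSM parameters: a low-rank delta (dt_rank)
        # plus B and C (d_state each); this input dependence is what makes the SSM "selective".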
self.x_proj = nn.Linear( |
|
config.d_inner, config.dt_rank + config.d_state * 2, bias=False |
|
) |
|
|
|
|
|
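        # dt_proj expands the low-rank delta from dt_rank back up to d_inner.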
self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True) |
|
|
|
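        # A is initialized to 1..d_state per channel and stored in log space, so that
        # A = -exp(A_log) in ssm() stays strictly negative; D is the learned skip term.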
A = repeat(torch.arange(1, config.d_state + 1), "n -> d n", d=config.d_inner) |
|
self.A_log = nn.Parameter(torch.log(A)) |
|
self.D = nn.Parameter(torch.ones(config.d_inner)) |
|
self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias) |
|
|
|
|
|
def forward(self, x): |
|
"""Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1]. |
|
|
|
Args: |
|
            x: shape (b, l, d), where b = batch size, l = sequence length, and d = d_model.
|
|
|
Returns: |
|
output: shape (b, l, d) |
|
|
|
Official Implementation: |
|
class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119 |
|
mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311 |
|
|
|
""" |
|
|
|
(b, l, d) = x.shape |
|
|
|
|
|
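        # Project up to 2 * d_inner channels and split into the SSM branch (x)
        # and the gating branch (res) of Figure 3.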
x_and_res = self.in_proj(x) |
|
(x, res) = x_and_res.split( |
|
split_size=[self.config.d_inner, self.config.d_inner], dim=-1 |
|
) |
|
|
|
x = rearrange(x, "b l d_in -> b d_in l") |
|
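        # Apply the depthwise conv along the sequence; truncating to length l removes
        # the extra positions introduced by the causal padding.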
x = self.conv1d(x)[:, :, :l] |
|
x = rearrange(x, "b d_in l -> b l d_in") |
|
|
|
x = F.silu(x) |
|
|
|
y = self.ssm(x) |
|
|
|
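        # Gate the SSM output with the SiLU-activated residual branch (the multiplicative gate in Figure 3).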
y = y * F.silu(res) |
|
|
|
|
|
output = self.out_proj(y) |
|
|
|
return output |
|
|
|
def ssm(self, x): |
|
"""Runs the SSM. See: |
|
- Algorithm 2 in Section 3.2 in the Mamba paper [1] |
|
- run_SSM(A, B, C, u) in The Annotated S4 [2] |
|
|
|
Args: |
|
            x: shape (b, l, d_in), where b = batch size, l = sequence length, and d_in = d_inner.
|
|
|
Returns: |
|
output: shape (b, l, d_in) |
|
|
|
Official Implementation: |
|
mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311 |
|
|
|
""" |
|
(d_in, n) = self.A_log.shape |
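        # Compute delta, A, B, C, D: A and D are input-independent, while delta, B and C
        # are computed from x (the selective part of the model).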
|
A = -torch.exp(self.A_log.float()) |
|
D = self.D.float() |
|
|
|
x_dbl = self.x_proj(x) |
|
|
|
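        # Split into delta (rank dt_rank), B and C (each of size n = d_state);
        # softplus keeps the step size delta strictly positive.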
(delta, B, C) = x_dbl.split( |
|
split_size=[self.config.dt_rank, n, n], dim=-1 |
|
) |
|
delta = F.softplus(self.dt_proj(delta)) |
|
|
|
y = self.selective_scan( |
|
x, delta, A, B, C, D |
|
) |
|
|
|
return y |
|
|
|
def selective_scan(self, u, delta, A, B, C, D): |
|
"""Does selective scan algorithm. See: |
|
- Section 2 State Space Models in the Mamba paper [1] |
|
- Algorithm 2 in Section 3.2 in the Mamba paper [1] |
|
- run_SSM(A, B, C, u) in The Annotated S4 [2] |
|
|
|
This is the classic discrete state space formula: |
|
x(t + 1) = Ax(t) + Bu(t) |
|
y(t) = Cx(t) + Du(t) |
|
        except B and C (and the step size delta, which is used for discretization) are dependent on the input u(t).
|
|
|
Args: |
|
            u: shape (b, l, d_in), where b = batch size, l = sequence length, d_in = d_inner, and n = d_state.
|
delta: shape (b, l, d_in) |
|
A: shape (d_in, n) |
|
B: shape (b, l, n) |
|
C: shape (b, l, n) |
|
D: shape (d_in,) |
|
|
|
Returns: |
|
output: shape (b, l, d_in) |
|
|
|
Official Implementation: |
|
selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86 |
|
            Note: some parts of `selective_scan_ref` were refactored out, so the functionality doesn't match exactly.
|
|
|
""" |
|
(b, l, d_in) = u.shape |
|
n = A.shape[1] |
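        # Discretize the continuous parameters: A via zero-order hold (ZOH),
        # B via a simplified Euler step, both using the input-dependent step size delta.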
|
deltaA = torch.exp(einsum(delta, A, "b l d_in, d_in n -> b d_in l n")) |
|
deltaB_u = einsum(delta, B, u, "b l d_in, b l n, b l d_in -> b d_in l n") |
|
|
|
|
|
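        # Sequential (recurrent) scan over the time dimension:
        #     x_t = deltaA_t * x_{t-1} + deltaB_u_t;  y_t = sum over n of (C_t[n] * x_t[..., n])
        # The official CUDA kernel performs this with a parallel scan rather than a Python loop.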
x = torch.zeros((b, d_in, n), device=deltaA.device) |
|
ys = [] |
|
for i in range(l): |
|
x = deltaA[:, :, i] * x + deltaB_u[:, :, i] |
|
y = einsum(x, C[:, i, :], "b d_in n, b n -> b d_in") |
|
ys.append(y) |
|
y = torch.stack(ys, dim=1) |
|
|
|
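        # Skip connection: the D * u(t) term of the state-space output.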
y = y + u * D |
|
|
|
return y |
|
|
|
|
|
class MambaBlock(nn.Module): |
|
def __init__(self, config: MambaConfig, layer_idx: int = 0): |
|
"""A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1].""" |
|
super().__init__() |
|
        self.config = config
        self.layer_idx = layer_idx
|
|
|
self.mixer = Mamba(config) |
|
self.norm = MambaRMSNorm(config.d_model) |
|
|
|
def forward(self, x): |
|
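        # Pre-norm residual block: normalize, run the Mamba mixer, then add the input back.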
return self.mixer(self.norm(x)) + x |
|
|
|
|
|
class MambaPreTrainedModel(PreTrainedModel): |
|
config_class = MambaConfig |
|
base_model_prefix = "backbone" |
|
supports_gradient_checkpointing = True |
|
_no_split_modules = ["MambaBlock"] |
|
|
|
def _init_weights(self, module): |
|
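        # Generic HF-style initialization; the reference Mamba implementation uses a more
        # specialized scheme (e.g. for dt_proj), which this port does not reproduce.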
std = 0.02 |
|
if isinstance(module, (nn.Linear, nn.Conv1d)): |
|
module.weight.data.normal_(mean=0.0, std=std) |
|
if module.bias is not None: |
|
module.bias.data.zero_() |
|
elif isinstance(module, nn.Embedding): |
|
module.weight.data.normal_(mean=0.0, std=std) |
|
if module.padding_idx is not None: |
|
module.weight.data[module.padding_idx].zero_() |
|
|
|
|
|
class MambaModel(MambaPreTrainedModel): |
|
def __init__(self, config: MambaConfig): |
|
"""Full Mamba model. |
|
Mamba model decoder consisting of *config.n_layer* layers. Each layer is a [`MambaBlock`] |
|
|
|
Args: |
|
config: MambaConfig |
|
""" |
|
super().__init__(config) |
|
|
|
|
|
self.embedding = nn.Embedding(self.config.vocab_size, self.config.d_model) |
|
self.layers = nn.ModuleList( |
|
[MambaBlock(self.config, layer_idx) for layer_idx in range(self.config.n_layer)] |
|
) |
|
self.norm_f = MambaRMSNorm(self.config.d_model) |
|
|
|
self.gradient_checkpointing = False |
|
self.post_init() |
|
|
|
def forward( |
|
self, |
|
input_ids: torch.LongTensor = None, |
|
output_hidden_states=False, |
|
return_dict: Optional[bool] = None, |
|
**kwargs, |
|
) -> BaseModelOutputWithPast: |
|
        batch_size = input_ids.shape[0]
        sequence_length = input_ids.shape[1]
        hidden_size = self.config.d_model
        # Collected hidden states; each entry has shape (batch_size, sequence_length, hidden_size).
        hidden_states: Tuple[torch.Tensor, ...] = ()
        output_hidden_states = output_hidden_states or self.config.output_hidden_states
|
|
|
last_hidden_state = self.embedding(input_ids) |
|
assert last_hidden_state.shape == ( |
|
batch_size, |
|
sequence_length, |
|
hidden_size, |
|
), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}" |
|
hidden_states += (last_hidden_state,) |
|
|
|
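        # Each block output is appended below; together with the embedding output above and the
        # final normalized state, hidden_states ends up holding n_layer + 2 tensors.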
for layer in self.layers: |
|
last_hidden_state = layer(last_hidden_state) |
|
assert last_hidden_state.shape == ( |
|
batch_size, |
|
sequence_length, |
|
hidden_size, |
|
), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}" |
|
hidden_states += (last_hidden_state,) |
|
|
|
last_hidden_state = self.norm_f(last_hidden_state) |
|
assert last_hidden_state.shape == ( |
|
batch_size, |
|
sequence_length, |
|
hidden_size, |
|
), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}" |
|
hidden_states += (last_hidden_state,) |
|
|
|
assert ( |
|
len(hidden_states) == self.config.n_layer + 2 |
|
), f"{len(hidden_states)} != {self.config.n_layer + 2}" |
|
|
|
return BaseModelOutputWithPast( |
|
hidden_states=hidden_states if output_hidden_states else None, |
|
last_hidden_state=last_hidden_state, |
|
) |
|
|
|
|
|
class MambaModelForCausalLM(MambaPreTrainedModel): |
|
_tied_weights_keys = ["lm_head.weight"] |
|
|
|
def __init__(self, config, **kwargs): |
|
super().__init__( |
|
config, |
|
**kwargs, |
|
) |
|
|
|
self.backbone = MambaModel( |
|
config=self.config, |
|
) |
|
|
|
self.lm_head = nn.Linear( |
|
in_features=self.config.d_model, |
|
out_features=self.config.vocab_size, |
|
bias=False, |
|
) |
|
|
|
self.post_init() |
|
|
|
def _tie_weights(self): |
|
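        # Weight tying: the LM head shares its weight matrix with the token embedding.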
self.lm_head.weight = self.backbone.embedding.weight |
|
|
|
def forward( |
|
self, |
|
input_ids, |
|
labels: Optional[torch.LongTensor] = None, |
|
output_hidden_states=False, |
|
**kwargs, |
|
) -> CausalLMOutputWithPast: |
|
batch_size = input_ids.shape[0] |
|
output_hidden_states = output_hidden_states or self.config.output_hidden_states |
|
sequence_length = input_ids.shape[1] |
|
vocab_size = self.config.vocab_size |
|
|
|
outputs = self.backbone( |
|
input_ids=input_ids, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
|
|
last_hidden_state = outputs.last_hidden_state |
|
|
|
        # logits has shape (batch_size, sequence_length, vocab_size)
        logits: torch.FloatTensor = self.lm_head(last_hidden_state)
|
|
|
if labels is not None: |
|
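            # Standard next-token objective: shift logits left and labels right by one position.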
shift_logits = logits[..., :-1, :].contiguous() |
|
shift_labels = labels[..., 1:].contiguous() |
|
loss_fct = CrossEntropyLoss() |
|
shift_logits = shift_logits.view(-1, vocab_size) |
|
shift_labels = shift_labels.view(-1) |
|
|
|
shift_labels = shift_labels.to(shift_logits.device) |
|
loss = loss_fct(shift_logits, shift_labels) |
|
|
|
else: |
|
loss = None |
|
|
|
return CausalLMOutputWithPast( |
|
hidden_states=outputs.hidden_states if output_hidden_states else None, |
|
logits=logits, |
|
loss=loss, |
|
) |
|
|
|
def prepare_inputs_for_generation( |
|
self, input_ids, attention_mask=None, **model_kwargs |
|
): |
|
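        # No cache or recurrent state is carried between steps; generation simply re-runs
        # the full sequence of input_ids on every call.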
return { |
|
"input_ids": input_ids, |
|
} |
|
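# A minimal usage sketch (assuming MambaConfig exposes the fields referenced above, e.g.
# vocab_size, d_model, n_layer, d_inner, d_conv, d_state, dt_rank, bias, conv_bias and
# output_hidden_states, and that it can be instantiated directly):
#
#     config = MambaConfig()
#     model = MambaModelForCausalLM(config)
#     input_ids = torch.randint(0, config.vocab_size, (1, 16))
#     out = model(input_ids, labels=input_ids)
#     print(out.logits.shape)  # (1, 16, config.vocab_size)
#     print(out.loss)          # scalar cross-entropy loss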