# mamba-130m / modeling_mamba.py
import json
import math
import os
from collections import namedtuple
from dataclasses import dataclass
from functools import partial
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from einops import einsum, rearrange, repeat
from torch import FloatTensor, Tensor
from transformers import GenerationMixin, PreTrainedModel
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPast,
CausalLMOutput,
ImageClassifierOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
)
from trl import PreTrainedModelWrapper
from .configuration_mamba import MambaConfig
# class SwiGLU(nn.Module):
#     def forward(self, x, W, V, b, c, beta):
#         # SwiGLU(x) = Swish_beta(x @ W + b) ⊗ (x @ V + c); F.silu corresponds to beta = 1.
#         return F.silu(x @ W + b) * (x @ V + c)
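# A minimal, runnable version of the SwiGLU block sketched above and referenced in the
# MambaBlock docstring below (an illustrative example only; it is not used by the model in
# this file). SwiGLU(x) = Swish(xW + b) ⊗ (xV + c), where ⊗ is elementwise multiplication
# and F.silu is Swish with β = 1.
class SwiGLUExample(nn.Module):
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w = nn.Linear(d_model, d_ff)  # gate branch (W, b)
        self.v = nn.Linear(d_model, d_ff)  # value branch (V, c)
    def forward(self, x):
        return F.silu(self.w(x)) * self.v(x)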
# Inspired by:
# - https://huggingface.co/Q-bert/Mamba-130M/blob/f0d00db98acaa62b1ee4304cd11643e69aa62a71/modeling_mamba.py#L31
# - https://github.com/johnma2006/mamba-minimal/blob/03de542a36d873f6e6c4057ad687278cc6ae944d/model.py#L177
# - https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/mamba_ssm/modules/mamba_simple.py#L31
class MambaBlock(nn.Module):
def __init__(self, config: MambaConfig):
"""A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1].
Furthermore, in section E.2.2 of the paper, the authors describe the Mamba block as:
"[T]he Mamba block is simply the standard SwiGLU block with an extra conv → SSM path added."
"""
super().__init__()
self.config = config
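        # in_proj expands d_model -> 2 * d_inner: one half feeds the conv -> SSM path,
        # the other half becomes the gating/residual branch (`res` in forward()).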
self.in_proj = nn.Linear(config.d_model, config.d_inner * 2, bias=config.bias)
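        # Depthwise (groups=d_inner) 1-D convolution; padding=d_conv - 1 together with the
        # [:, :, :l] slice in forward() makes it causal.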
self.conv1d = nn.Conv1d(
in_channels=config.d_inner,
out_channels=config.d_inner,
bias=config.conv_bias,
kernel_size=config.d_conv,
groups=config.d_inner,
padding=config.d_conv - 1,
)
# x_proj takes in `x` and outputs the input-specific Δ, B, C
self.x_proj = nn.Linear(
config.d_inner, config.dt_rank + config.d_state * 2, bias=False
)
# dt_proj projects Δ from dt_rank to d_in
self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True)
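        # A is stored as A_log = log([1, ..., d_state]) per channel and recovered as
        # A = -exp(A_log) in ssm(), so the effective A = -[1, 2, ..., d_state] (the real part
        # of the S4D initialization) is always negative and the recurrence stays stable.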
A = repeat(torch.arange(1, config.d_state + 1), "n -> d n", d=config.d_inner)
self.A_log = nn.Parameter(torch.log(A))
self.D = nn.Parameter(torch.ones(config.d_inner))
self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias)
# self.norm = RMSNorm(config.d_model)
def forward(self, x):
"""Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1].
Args:
            x: shape (b, l, d), where b is the batch size, l the sequence length, and d the model dimension (d_model)
Returns:
output: shape (b, l, d)
Official Implementation:
class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
"""
(b, l, d) = x.shape
# x_copy = x # There was a separate class for residual, I deleted that part and added it here.
# x = self.norm(x)
x_and_res = self.in_proj(x) # shape (b, l, 2 * d_in)
(x, res) = x_and_res.split(
split_size=[self.config.d_inner, self.config.d_inner], dim=-1
)
x = rearrange(x, "b l d_in -> b d_in l")
x = self.conv1d(x)[:, :, :l]
x = rearrange(x, "b d_in l -> b l d_in")
x = F.silu(x)
y = self.ssm(x)
        y = y * F.silu(res)  # SwiGLU-style gating: Swish(xW + b) ⊗ (xV + c), with ⊗ elementwise, i.e. F.silu(res) * y (not a Kronecker product)
output = self.out_proj(y) # output = self.out_proj(y) + x_copy
# "the Mamba block is simply the standard SwiGLU block with an extra 𝖼𝗈𝗇𝗏 → 𝖲𝖲𝖬 path added"
return output
def ssm(self, x):
"""Runs the SSM. See:
- Algorithm 2 in Section 3.2 in the Mamba paper [1]
- run_SSM(A, B, C, u) in The Annotated S4 [2]
Args:
            x: shape (b, l, d_in), where d_in = d_inner; n below refers to the SSM state size d_state
Returns:
output: shape (b, l, d_in)
Official Implementation:
mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
"""
(d_in, n) = self.A_log.shape
# Compute ∆ A B C D, the state space parameters.
# A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
# ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
# and is why Mamba is called **selective** state spaces)
A = -torch.exp(self.A_log.float()) # shape (d_in, n)
D = self.D.float()
x_dbl = self.x_proj(x) # (b, l, dt_rank + 2*n)
(delta, B, C) = x_dbl.split(
split_size=[self.config.dt_rank, n, n], dim=-1
) # delta: (b, l, dt_rank). B, C: (b, l, n)
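        # softplus keeps the discretization step Δ strictly positive.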
delta = F.softplus(self.dt_proj(delta)) # (b, l, d_in)
y = self.selective_scan(
x, delta, A, B, C, D
) # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]
return y
def selective_scan(self, u, delta, A, B, C, D):
"""Does selective scan algorithm. See:
- Section 2 State Space Models in the Mamba paper [1]
- Algorithm 2 in Section 3.2 in the Mamba paper [1]
- run_SSM(A, B, C, u) in The Annotated S4 [2]
This is the classic discrete state space formula:
x(t + 1) = Ax(t) + Bu(t)
y(t) = Cx(t) + Du(t)
        except that B and C (and the step size delta, which is used for discretization) are dependent on the input u(t).
Args:
            u: shape (b, l, d_in), where b is the batch size, l the sequence length, d_in = d_inner, and n = d_state
delta: shape (b, l, d_in)
A: shape (d_in, n)
B: shape (b, l, n)
C: shape (b, l, n)
D: shape (d_in,)
Returns:
output: shape (b, l, d_in)
Official Implementation:
selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
        Note: some parts of `selective_scan_ref` have been refactored out here, so the functionality doesn't match it exactly.
"""
(b, l, d_in) = u.shape
n = A.shape[1]
# Discretize continuous parameters (A, B)
# - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1])
# - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
# "A is the more important term and the performance doesn't change much with the simplification on B"
deltaA = torch.exp(einsum(delta, A, "b l d_in, d_in n -> b l d_in n"))
deltaB_u = einsum(delta, B, u, "b l d_in, b l n, b l d_in -> b l d_in n")
# Perform selective scan (see scan_SSM() in The Annotated S4 [2])
# Note that the below is sequential, while the official implementation does a much faster parallel scan that
# is additionally hardware-aware (like FlashAttention).
x = torch.zeros((b, d_in, n), device=deltaA.device)
ys = []
for i in range(l):
x = deltaA[:, i] * x + deltaB_u[:, i]
y = einsum(x, C[:, i, :], "b d_in n, b n -> b d_in")
ys.append(y)
y = torch.stack(ys, dim=1) # shape (b, l, d_in)
y = y + u * D
return y
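# A small usage sketch (hypothetical, not part of the model): MambaBlock only reads plain
# attributes off its config, so a types.SimpleNamespace with illustrative dimensions is enough
# to run a single block on random input and check that the output shape matches the input.
def _example_mamba_block_forward():
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        d_model=64,
        d_inner=128,
        d_conv=4,
        d_state=16,
        dt_rank=4,
        bias=False,
        conv_bias=True,
    )
    block = MambaBlock(cfg)
    x = torch.randn(2, 10, cfg.d_model)  # (batch, length, d_model)
    y = block(x)
    assert y.shape == x.shape
    return y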
# Inspired by:
# - https://huggingface.co/Q-bert/Mamba-130M/blob/f0d00db98acaa62b1ee4304cd11643e69aa62a71/modeling_mamba.py#L19
# - https://github.com/johnma2006/mamba-minimal/blob/03de542a36d873f6e6c4057ad687278cc6ae944d/model.py#L328
# - https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/mamba_ssm/ops/triton/layernorm.py#L481
class RMSNorm(nn.Module):
def __init__(self, d_model: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(d_model))
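    # RMSNorm rescales by the root-mean-square of the features (no mean-centering and no bias,
    # unlike LayerNorm) and then applies a learned per-feature weight.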
def forward(self, x):
output = (
x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
)
return output
# Copied and modified from:
# - https://github.com/johnma2006/mamba-minimal/blob/03de542a36d873f6e6c4057ad687278cc6ae944d/model.py#L143
class ResidualBlock(nn.Module):
def __init__(self, config: MambaConfig):
"""Simple block wrapping Mamba block with normalization and residual connection."""
super().__init__()
# self.args = args
self.mixer = MambaBlock(config)
self.norm = RMSNorm(config.d_model)
def forward(self, x):
"""
Args:
            x: shape (b, l, d), where b is the batch size, l the sequence length, and d the model dimension (d_model)
Returns:
output: shape (b, l, d)
Official Implementation:
Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297
Note: the official repo chains residual blocks that look like
[Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
where the first Add is a no-op. This is purely for performance reasons as this
allows them to fuse the Add->Norm.
We instead implement our blocks as the more familiar, simpler, and numerically equivalent
[Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ....
"""
output = self.mixer(self.norm(x)) + x
return output
# Inspired by:
# - https://huggingface.co/Q-bert/Mamba-130M/blob/f0d00db98acaa62b1ee4304cd11643e69aa62a71/modeling_mamba.py#L181
# class MambaPretrainedModel(PreTrainedModel, nn.Module):
class MambaPretrainedModel(PreTrainedModel):
r"""
    Base class for the Mamba models defined in this file.
    [`PreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
    downloading and saving models, as well as a few methods common to all models (e.g. resizing the input
    embeddings).
Class attributes (overridden by derived classes):
- **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
for this model architecture.
- **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
taking as arguments:
- **model** ([`PreTrainedModel`]) -- An instance of the model on which to load the TensorFlow checkpoint.
- **config** ([`PreTrainedConfig`]) -- An instance of the configuration associated to the model.
- **path** (`str`) -- A path to the TensorFlow checkpoint.
- **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
classes of the same architecture adding modules on top of the base model.
- **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization.
- **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
models, `pixel_values` for vision models and `input_values` for speech models).
"""
config_class = MambaConfig # TODO: Build on top of MambaConfig?
# base_model_prefix = "backbone"
base_model_prefix = "mamba"
main_input_name = "input_ids"
model_tags = None
_auto_class = None
_no_split_modules = ["MambaBlock"]
_skip_keys_device_placement = None
_keep_in_fp32_modules = None
# a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
# keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
_keys_to_ignore_on_load_missing = None
# a list of `re` patterns of `state_dict` keys that should be removed from the list of
# unexpected keys we find (keys inside the checkpoint but not the model) and avoid unnecessary
# warnings.
_keys_to_ignore_on_load_unexpected = None
# a list of `state_dict` keys to ignore when saving the model (useful for keys that aren't
# trained, but which are either deterministic or tied variables)
_keys_to_ignore_on_save = None
# a list of `state_dict` keys that are potentially tied to another key in the state_dict.
_tied_weights_keys = None
is_parallelizable = False
supports_gradient_checkpointing = True
# Flash Attention 2 support
_supports_flash_attn_2 = False
# SDPA support
_supports_sdpa = False
# Has support for a `Cache` instance as `past_key_values`
_supports_cache_class = False
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
# https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/mamba_ssm/models/mixer_seq_simple.py#L54
def _init_weights(
self,
module,
initializer_range=0.02, # Now only used for embedding layer.
rescale_prenorm_residual=True,
n_residuals_per_layer=1, # Change to 2 if we have MLP
):
if isinstance(module, nn.Linear):
if module.bias is not None:
if not getattr(module.bias, "_no_reinit", False):
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=initializer_range)
if rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["out_proj.weight", "fc2.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
# We need to reinit p since this code could be called multiple times
# Having just p *= scale would repeatedly scale it down
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
with torch.no_grad():
p /= math.sqrt(n_residuals_per_layer * self.config.n_layer)
# def _set_gradient_checkpointing(self, module, value=False):
# if isinstance(module, GPT2Model):
# module.gradient_checkpointing = value
class MambaModel(MambaPretrainedModel):
def __init__(
self, config: MambaConfig = MambaConfig(), **kwargs
) -> None:
"""Full Mamba model.
Mamba model decoder consisting of *config.n_layer* layers. Each layer is a [`MambaBlock`]
Args:
config: MambaConfig
"""
super().__init__(
config,
**kwargs,
)
self.embedding = nn.Embedding(
num_embeddings=config.vocab_size,
embedding_dim=config.d_model,
)
self.layers = nn.ModuleList(
[ResidualBlock(config) for _ in range(self.config.n_layer)]
)
# self.layers = nn.ModuleList([MambaBlock(config) for _ in range(config.n_layer)])
# # self.norm_f = RMSNorm(d_model=embedding_dim)
self.norm_f = RMSNorm(config.d_model)
# self.gradient_checkpointing = False
# # self.post_init()
# Initialize weights and apply final processing
self.post_init()
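    # NOTE: a minimal forward sketch (an assumption, not the official implementation): embed the
    # tokens, run them through the stack of residual Mamba blocks, and apply the final RMSNorm.
    # Output handling follows the usual transformers `output_hidden_states` / `return_dict` flags.
    def forward(
        self,
        input_ids: torch.LongTensor,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        all_hidden_states = () if output_hidden_states else None
        hidden_states = self.embedding(input_ids)  # (b, l, d_model)
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            hidden_states = layer(hidden_states)
        hidden_states = self.norm_f(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        if not return_dict:
            return tuple(v for v in (hidden_states, all_hidden_states) if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
        )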
# def _init_weights(self, module):
# std = 0.02
# if isinstance(module, (nn.Linear, nn.Conv1d)):
# module.weight.data.normal_(mean=0.0, std=std)
# if module.bias is not None:
# module.bias.data.zero_()
# elif isinstance(module, nn.Embedding):
# module.weight.data.normal_(mean=0.0, std=std)
# if module.padding_idx is not None:
# module.weight.data[module.padding_idx].zero_()
# Inspired by:
# - https://huggingface.co/Q-bert/Mamba-130M/blob/f0d00db98acaa62b1ee4304cd11643e69aa62a71/modeling_mamba.py#L198
# - https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/mamba_ssm/models/mixer_seq_simple.py#L86
# class MambaModel(MambaPretrainedModel):
# def __init__(
# self,
# config: MambaConfig = MambaConfig(),
# **kwargs,
# ) -> None:
# super().__init__(
# config,
# **kwargs,
# )
# self.embedding = nn.Embedding(
# num_embeddings=config.vocab_size,
# embedding_dim=config.d_model,
# )
# # # self.layers = nn.ModuleList(
# # # [ResidualBlock(args=model_args) for _ in range(model_args.n_layer)]
# # # )
# self.layers = nn.ModuleList([MambaBlock(config) for _ in range(config.n_layer)])
# # # self.norm_f = RMSNorm(d_model=embedding_dim)
# self.norm_f = RMSNorm(config.d_model)
# # self.gradient_checkpointing = False
# # # self.post_init()
# def get_input_embeddings(self):
# return self.embed_out
# def set_input_embeddings(self, value):
# self.embed_out = value
# def forward(
# self,
# input_ids: torch.LongTensor = None,
# output_hidden_states=False,
# return_dict: Optional[bool] = None,
# **kwargs,
# # ) -> BaseModelOutput:
# ) -> Union[Tuple, BaseModelOutputWithPast]:
# batch_size = input_ids.shape[0]
# hidden_size = self.config.hidden_size
# hidden_states: Tuple[Tensor[(batch_size, sequence_length, hidden_size)]] = ()
# sequence_length = input_ids.shape[1]
# output_hidden_states = output_hidden_states or self.config.output_hidden_states
# last_hidden_state = self.embed_out(input_ids)
# assert last_hidden_state.shape == (
# batch_size,
# sequence_length,
# hidden_size,
# ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
# hidden_states += (last_hidden_state,)
# for layer in self.layers:
# last_hidden_state = layer(last_hidden_state)
# assert last_hidden_state.shape == (
# batch_size,
# sequence_length,
# hidden_size,
# ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
# hidden_states += (last_hidden_state,)
# last_hidden_state = self.norm_f(last_hidden_state)
# assert last_hidden_state.shape == (
# batch_size,
# sequence_length,
# hidden_size,
# ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
# hidden_states += (last_hidden_state,)
# assert (
# len(hidden_states) == self.config.n_layer + 2
# ), f"{len(hidden_states)} != {self.config.n_layer + 2}"
# # return BaseModelOutput(
# return BaseModelOutputWithPast(
# hidden_states=hidden_states if output_hidden_states else None,
# last_hidden_state=last_hidden_state,
# )
# Influences:
# - https://huggingface.co/Q-bert/Mamba-130M/blob/f0d00db98acaa62b1ee4304cd11643e69aa62a71/modeling_mamba.py#L238
# - https://github.com/state-spaces/mamba/blob/009bec5ee37f586844a3fc89c040a9c1a9d8badf/mamba_ssm/models/mixer_seq_simple.py#L176
# class MambaModelForCausalLM(MambaModel, GenerationMixin):
# class MambaModelForCausalLM(PreTrainedModel, GenerationMixin):
# class MambaLMHeadModel(MambaPretrainedModel, GenerationMixin):
class MambaLMHeadModel(MambaPretrainedModel):
    # _tied_weights_keys = ["lm_head.weight"]
def __init__(
self,
config: MambaConfig = MambaConfig(),
**kwargs,
) -> None:
super().__init__(
config,
**kwargs,
)
self.backbone = MambaModel(
config=self.config,
)
self.lm_head = nn.Linear(
in_features=self.config.hidden_size,
out_features=self.config.vocab_size,
bias=False,
)
        # # self.lm_head.weight = self.backbone.embedding.weight  # TODO: there's some logic in transformers' PreTrainedModel.tie_weights() that does this
# Initialize weights and apply final processing
self.post_init()
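    # NOTE: a minimal causal-LM forward sketch (an assumption, mirroring the commented-out
    # version below but using `self.backbone` / `self.lm_head` as defined above): run the
    # backbone, project to vocabulary logits, and optionally compute a shifted LM loss.
    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> CausalLMOutput:
        outputs = self.backbone(
            input_ids=input_ids,
            output_hidden_states=output_hidden_states,
        )
        logits = self.lm_head(outputs.last_hidden_state)  # (b, l, vocab_size)
        loss = None
        if labels is not None:
            # Standard causal-LM shift: position t predicts token t + 1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
        return CausalLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )
    # Without a recurrent cache in this simplified implementation, generation re-feeds the full
    # sequence at every step (matching the commented-out sketch below).
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}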
# # def forward(
# # self, input_ids, output_hidden_states=False, **kwargs
# # ) -> CausalLMOutput:
# # batch_size = input_ids.shape[0]
# # sequence_length = input_ids.shape[1]
# # vocab_size = self.config.vocab_size
# # output_hidden_states = output_hidden_states or self.config.output_hidden_states
# # outputs = self.backbone(
# # input_ids=input_ids,
# # output_hidden_states=output_hidden_states,
# # )
# # last_hidden_state = outputs.last_hidden_state
# # logits: torch.FloatTensor[batch_size, sequence_length, vocab_size] = (
# # self.lm_head(
# # last_hidden_state,
# # )
# # )
# # return CausalLMOutput(
# # hidden_states=outputs.hidden_states if output_hidden_states else None,
# # logits=logits,
# # )
# # def prepare_inputs_for_generation(
# # self, input_ids, attention_mask=None, **model_kwargs
# # ):
# # return {
# # "input_ids": input_ids,
# # }
# class MultimodalMambaModelForCausalLMWithValueHead(PreTrainedModelWrapper):
# lm_head_namings: Tuple[str, str] = ("lm_head", "embed_out")
# transformers_parent_class: transformers.PreTrainedModel = transformers.AutoModelForCausalLM
# # def __init__(
# # self,
# # config: MultimodalMambaConfig = MultimodalMambaConfig(),
# # **kwargs,
# # ) -> None:
# # super().__init__(
# # config,
# # **kwargs,
# # )
# # self.model = MultimodalMambaModelForCausalLM(
# # config=config,
# # )
# # self.value_head = nn.Linear(
# # in_features=config.embedding_dim,
# # out_features=1,
# # bias=False,
# # )
# # def forward(
# # self, input_ids, output_hidden_states=False, **kwargs
# # ) -> CausalLMOutput:
# # outputs = self.model(
# # input_ids=input_ids,
# # output_hidden_states=output_hidden_states,
# # )
# # last_hidden_state = outputs.last_hidden_state
# # value: torch.FloatTensor[batch_size, sequence_length, 1] = self.value_head(
# # last_hidden_state,
# # )
# # return CausalLMOutput(
# # hidden_states=outputs.hidden_states if output_hidden_states else None,
# # logits=outputs.logits,
# # value=value,
# # )