import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput

from .configuration_minimamba import MiniMambaConfig
from .model import Mamba2, Mamba2Config


class MiniMamba(PreTrainedModel):
    """
    A Hugging Face–style wrapper around a Mamba2 model, providing:
      • forward(...) returning a CausalLMOutput
      • support for HF training loops
      • a naive generate(...) method with temperature and top-k/top-p sampling

    See the usage sketch at the bottom of this file for a minimal end-to-end example.
    """

    config_class = MiniMambaConfig  # Tells HF which config class to use

    def __init__(self, config: MiniMambaConfig) -> None:
        """
        Initialize the MiniMamba model, bridging Mamba2 with HF's PreTrainedModel.
        """
        super().__init__(config)

        # If your config includes Mamba2-like parameters, you can build a Mamba2Config from it:
        mamba2_args = Mamba2Config(
            dim=config.dim,
            num_layers=config.num_layers,
            num_heads=config.num_heads,
            state_dim=config.state_dim,
            num_groups=config.num_groups,
            conv_size=config.conv_size,
            use_mem_eff_path=config.use_mem_eff_path,
            dt_bias=config.dt_bias,
            D_has_head_dim=config.D_has_head_dim,
            learnable_init_states=config.learnable_init_states,
            ssm_chunk_size=config.ssm_chunk_size,
            vocab_size=config.vocab_size,
            ffn_dim_multiplier=config.ffn_dim_multiplier,
            multiple_of=config.multiple_of,
            norm_eps=config.norm_eps,
            init_use_depth=config.init_use_depth,
            init_base_std=config.init_base_std,
            init_std_factor=config.init_std_factor,
            bias=config.bias,
            # Torch / training:
            seed=config.seed,
            # Additional Mamba / training fields (weight_tying may be absent from older configs):
            weight_tying=getattr(config, "weight_tying", False),
            torch_dtype=getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype,
        )
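
        # For reference, a checkpoint's config.json carries these same fields. The
        # snippet below is only a schematic with placeholder values (it does NOT
        # describe the 561M checkpoint):
        #
        #   {
        #     "model_type": "...",
        #     "dim": 1024, "num_layers": 24, "num_heads": 16,
        #     "state_dim": 128, "num_groups": 1, "conv_size": 4,
        #     "vocab_size": 50257, "weight_tying": true,
        #     "torch_dtype": "bfloat16"
        #   }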

        # Internally hold a Mamba2 model
        self.mamba = Mamba2(config=mamba2_args)

        # Mamba2 already provides its own token embedding (self.tok_emb) and output
        # projection (self.output), so this wrapper does not need a separate
        # embedding or lm_head. The "HF standard" tie-weights approach would be:
        #   self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        #   self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        #   self.lm_head.weight = self.tok_emb.weight
        # but Mamba2 already ties weights internally when config.weight_tying == True.

        # Optional: remember the device and dtype the model is expected to run with.
        self.device_ = 'cuda' if torch.cuda.is_available() else 'cpu'
        if isinstance(config.torch_dtype, str):
            self.dtype_ = getattr(torch, config.torch_dtype)
        else:
            self.dtype_ = config.torch_dtype

        # Parameter initialization (HF also invokes _init_weights in some of its own flows).
        self.apply(self._init_weights)
        print("MiniMamba Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: torch.LongTensor = None,
        **kwargs
    ) -> CausalLMOutput:
        """
        Forward pass for causal language modeling.
        Returns a CausalLMOutput containing the logits and, if labels are provided, the loss.
        """
        # Mamba2's forward expects (x: torch.Tensor, target: torch.Tensor | None, ...),
        # but here only the logits from the plain call are needed:
        logits = self.mamba(input_ids)  # shape: [batch, seq_len, vocab_size]

        loss = None
        if labels is not None:
            # As in Hugging Face GPT-style models, shift by one so position t predicts token t+1
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )
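
    # Training-step sketch (illustrative; `model`, `batch`, and `optimizer` are assumed
    # to exist and are not defined in this file). Because forward() shifts the labels
    # internally, labels can simply be a copy of input_ids, as with HF causal-LM models:
    #
    #   out = model(input_ids=batch["input_ids"], labels=batch["input_ids"].clone())
    #   out.loss.backward()
    #   optimizer.step()
    #   optimizer.zero_grad()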

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 50,
        temperature: float = 0.5,
        top_k: int = 50,
        top_p: float = 0.95,
        eos_token_id: int = None,
        pad_token_id: int = 0,
        **kwargs
    ):
        """
        A naive token-by-token generation loop (temperature scaling plus top-k/top-p sampling).
        Note: pad_token_id is accepted for API compatibility but is not used, and generation
        stops early only if every sequence in the batch samples eos_token_id in the same step.
        """
        # Accumulate the prompt plus newly sampled tokens in generated_ids
        generated_ids = input_ids.clone()

        for _ in range(max_new_tokens):
            # Forward pass over the full sequence so far; keep logits for the last position
            outputs = self.forward(generated_ids)
            logits = outputs.logits[:, -1, :]  # shape: (batch_size, vocab_size)

            # Scale by temperature
            if temperature != 1.0:
                logits = logits / temperature

            # Filter
            logits = self.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

            # Sample the next token
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # shape: (batch, 1)

            # Append
            generated_ids = torch.cat([generated_ids, next_token], dim=1)

            # If an EOS token is given, stop early once every sequence in the batch has emitted it
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break

        return generated_ids
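
    # Possible extension (sketch only, not implemented above): to make use of
    # pad_token_id, track which sequences have finished and overwrite their sampled
    # tokens before appending, e.g.
    #
    #   finished = torch.zeros(input_ids.size(0), dtype=torch.bool, device=input_ids.device)
    #   ...
    #   next_token = torch.where(finished.unsqueeze(1),
    #                            torch.full_like(next_token, pad_token_id),
    #                            next_token)
    #   finished |= next_token.squeeze(1) == eos_token_id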

    @staticmethod
    def top_k_top_p_filtering(
        logits: torch.Tensor,
        top_k: int = 50,
        top_p: float = 0.95,
        filter_value: float = float("-inf"),
    ):
        """
        Filters a batch of logits in place using top-k and/or nucleus (top-p) filtering.
        """
        # top-k: keep only the k highest-scoring tokens in each row
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))
            indices_to_remove = logits < torch.topk(logits, top_k, dim=-1).values[:, -1, None]
            logits[indices_to_remove] = filter_value

        # top-p (nucleus): keep the smallest set of tokens whose cumulative probability exceeds top_p
        if 0 < top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the first token that crosses the threshold is kept as well
            sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
            sorted_indices_to_remove[:, 0] = False

            # Scatter the mask back to the original (unsorted) token order
            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=1, index=sorted_indices, src=sorted_indices_to_remove
            )
            logits[indices_to_remove] = filter_value

        return logits
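
    # Worked example for top_k_top_p_filtering (values rounded): for a single row of
    # logits [2.0, 1.0, 0.5, -1.0], top_k=2 keeps only the two largest entries, giving
    # [2.0, 1.0, -inf, -inf]. With top_p=0.7 the surviving probabilities are roughly
    # [0.73, 0.27]; the cumulative sum already exceeds 0.7 at the first token, so only
    # [2.0, -inf, -inf, -inf] remains to be sampled from.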

    def _init_weights(self, module):
        """
        HF calls _init_weights to initialize parameters.
        To rely purely on Mamba's own init approach, call model.mamba.init_weights() instead.
        """
        # Delegate to Mamba2's init routine for the submodel as a whole, and use
        # standard PyTorch inits for plain linear layers and embeddings.
        if isinstance(module, Mamba2):
            module.init_weights()  # Mamba2's internal init
        elif isinstance(module, nn.Linear):
            # Standard normal init (std=0.02), as in GPT-style models
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        # Add specialized inits for other module types here if needed.

    def _get_num_params(self):
        # Count trainable parameters; tied weights share a single Parameter, which
        # self.parameters() already yields only once.
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
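

# ---------------------------------------------------------------------------
# Usage sketch (runs only when this file is executed directly). The repo id and
# tokenizer below are placeholders and assumptions, not part of this module;
# substitute the actual checkpoint published for this model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    repo_id = "<namespace>/<minimamba-checkpoint>"  # placeholder, replace with a real repo id
    tokenizer = AutoTokenizer.from_pretrained(repo_id)  # assumes a tokenizer ships with the checkpoint
    model = MiniMamba.from_pretrained(repo_id)

    # Training-style call: labels are just a copy of input_ids; forward() shifts them.
    batch = tokenizer("Mamba is a state-space model", return_tensors="pt")
    out = model(input_ids=batch["input_ids"], labels=batch["input_ids"].clone())
    print("loss:", out.loss.item(), "logits shape:", tuple(out.logits.shape))

    # Sampling with the naive generate() loop defined above.
    generated = model.generate(
        batch["input_ids"],
        max_new_tokens=20,
        temperature=0.8,
        top_k=50,
        top_p=0.95,
        eos_token_id=tokenizer.eos_token_id,
    )
    print(tokenizer.decode(generated[0], skip_special_tokens=True))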