import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput

from .configuration_minimamba import MiniMambaConfig
from .model import Mamba2, Mamba2Config


class MiniMamba(PreTrainedModel):
    """
    A Hugging Face-style wrapper around a Mamba2 model, providing:
      • forward(...) returning a CausalLMOutput
      • support for HF training loops
      • a naive generate(...) method with top-k/top-p sampling
    """

    config_class = MiniMambaConfig  # Tells HF which config class to use

    def __init__(self, config: MiniMambaConfig) -> None:
        """
        Initialize the MiniMamba model, bridging Mamba2 with HF's PreTrainedModel.
        """
        super().__init__(config)

        # If your config includes Mamba2-like parameters, you can build a Mamba2Config from it:
        mamba2_args = Mamba2Config(
            dim=config.dim,
            num_layers=config.num_layers,
            num_heads=config.num_heads,
            state_dim=config.state_dim,
            num_groups=config.num_groups,
            conv_size=config.conv_size,
            use_mem_eff_path=config.use_mem_eff_path,
            dt_bias=config.dt_bias,
            D_has_head_dim=config.D_has_head_dim,
            learnable_init_states=config.learnable_init_states,
            ssm_chunk_size=config.ssm_chunk_size,
            vocab_size=config.vocab_size,
            ffn_dim_multiplier=config.ffn_dim_multiplier,
            multiple_of=config.multiple_of,
            norm_eps=config.norm_eps,
            init_use_depth=config.init_use_depth,
            init_base_std=config.init_base_std,
            init_std_factor=config.init_std_factor,
            bias=config.bias,
            # Torch / training:
            seed=config.seed,
            # The init_config block nested in JSON:
            # Additional Mamba or training fields:
            weight_tying=getattr(config, "weight_tying", False),
            torch_dtype=getattr(torch, config.torch_dtype)
            if isinstance(config.torch_dtype, str)
            else config.torch_dtype,
        )

        # Internally hold a Mamba2 model
        self.mamba = Mamba2(config=mamba2_args)

        # Because HF wants the final linear to be part of this top-level model,
        # you *can* rely on Mamba2's built-in embedding + output if you prefer.
        # Mamba2 already has self.tok_emb and self.output,
        # so we typically do NOT need a separate embedding or lm_head here.
        #
        # We would only add them here if we wanted the "HF standard" tie-weights approach:
        #   self.tok_emb = nn.Embedding(config.vocab_size, config.dim)
        #   self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)
        #   self.lm_head.weight = self.tok_emb.weight
        #
        # But Mamba2 does that internally if config.weight_tying == True.

        # Optional: store any device or dtype you might want
        self.device_ = "cuda" if torch.cuda.is_available() else "cpu"
        if isinstance(config.torch_dtype, str):
            self.dtype_ = getattr(torch, config.torch_dtype)
        else:
            self.dtype_ = config.torch_dtype

        # Parameter initialization (HF calls _init_weights in some flows).
        self.apply(self._init_weights)

        print("MiniMamba Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: torch.LongTensor = None,
        **kwargs,
    ) -> CausalLMOutput:
        """
        Forward pass for causal language modeling.
        Returns a CausalLMOutput that includes loss (if labels is provided) and logits.
        """
        # Mamba2's forward expects (x: torch.Tensor, target: torch.Tensor | None, ...),
        # but we only need the logits from the simple call:
        logits = self.mamba(input_ids)  # shape: [batch, seq_len, vocab_size]

        loss = None
        if labels is not None:
            # By default, Hugging Face GPT-like models shift the logits by one
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 50,
        temperature: float = 0.5,
        top_k: int = 50,
        top_p: float = 0.95,
        eos_token_id: int = None,
        pad_token_id: int = 0,
        **kwargs,
    ):
        """
        A naive token-by-token generation loop (temperature scaling + top-k/top-p sampling).
        """
        # We'll accumulate new tokens in generated_ids
        generated_ids = input_ids.clone()

        for _ in range(max_new_tokens):
            # Forward pass to get logits for the last token
            outputs = self.forward(generated_ids)
            logits = outputs.logits[:, -1, :]  # shape: (batch_size, vocab_size)

            # Scale by temperature
            if temperature != 1.0:
                logits = logits / temperature

            # Filter
            logits = self.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

            # Sample next token
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # shape: (batch, 1)

            # Append
            generated_ids = torch.cat([generated_ids, next_token], dim=1)

            # If we have an EOS token, we can break early if all sequences have ended
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break

        return generated_ids

    @staticmethod
    def top_k_top_p_filtering(
        logits: torch.Tensor,
        top_k: int = 50,
        top_p: float = 0.95,
        filter_value: float = float("-inf"),
    ):
        """
        Filters logits using top-k and/or nucleus (top-p) filtering.
        """
        # top_k
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))
            indices_to_remove = logits < torch.topk(logits, top_k, dim=-1).values[:, -1, None]
            logits[indices_to_remove] = filter_value

        # top_p (nucleus)
        if 0 < top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the first token above the threshold is also kept
            sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
            sorted_indices_to_remove[:, 0] = False

            # Scatter back to the original (unsorted) indexing
            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=1, index=sorted_indices, src=sorted_indices_to_remove
            )
            logits[indices_to_remove] = filter_value

        return logits

    def _init_weights(self, module):
        """
        HF calls _init_weights to initialize parameters.
        If you prefer Mamba's own init approach, you can call model.mamba.init_weights().
        """
        # As an example, we call Mamba2's init routine for the entire submodel
        # and use standard PyTorch inits for linear layers, embeddings, etc.
        if isinstance(module, Mamba2):
            module.init_weights()  # Mamba2's internal init
        elif isinstance(module, nn.Linear):
            # e.g. a standard normal init (Xavier would also be reasonable)
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        # If needed, add specialized inits for other modules here

    def _get_num_params(self):
        # Count trainable params (subtract duplicates if tying weights, etc.)
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
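

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module's API).
# Assumptions: MiniMambaConfig supplies sensible defaults for any fields not
# passed below (use_mem_eff_path, dt_bias, seed, ...), and the placeholder
# sizes (dim=256, vocab_size=32000, ...) are hypothetical, chosen only to
# smoke-test the wrapper's forward/generate path on random weights.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Build a tiny config and model purely for a shape/flow check.
    config = MiniMambaConfig(
        dim=256,
        num_layers=4,
        num_heads=4,
        state_dim=64,
        num_groups=1,
        conv_size=4,
        vocab_size=32000,
        torch_dtype="float32",
    )
    model = MiniMamba(config)

    # Fake a batch of token ids and run a short sampled continuation.
    prompt = torch.randint(0, config.vocab_size, (1, 8), dtype=torch.long)
    out = model.generate(prompt, max_new_tokens=16, temperature=0.8, top_k=50, top_p=0.95)
    print(out.shape)  # expected: (1, 8 + 16) if no EOS token is hit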