from transformers import PretrainedConfig


class MiniMambaConfig(PretrainedConfig):
    """
    Minimal or extended config class for MiniMamba.

    Inherits from HF's PretrainedConfig so we can do:
        model = MiniMamba.from_pretrained(...)
    and it will load this config automatically.

    This config includes all fields from the provided config.json.
    """

    model_type = "minimamba"

    def __init__(
        self,
        # Standard HF fields:
        model_type="minimamba",
        _name_or_path="Mamba_500M",
        architectures=["MiniMamba"],
        # Key Mamba architecture hyperparameters:
        dim=1024,
        num_layers=54,
        num_heads=32,
        state_dim=128,
        num_groups=1,
        conv_size=4,
        use_mem_eff_path=True,
        dt_bias=True,
        D_has_head_dim=True,
        learnable_init_states=False,
        ssm_chunk_size=256,
        vocab_size=200064,
        ffn_dim_multiplier=2.0,
        multiple_of=256,
        norm_eps=1e-5,
        init_use_depth=False,
        init_base_std=None,
        init_std_factor="disabled",
        hidden_act="silu",
        bias=False,
        # Torch / training:
        torch_dtype="bfloat16",
        seed=1337,
        # The init_config block nested in JSON:
        init_args=None,  # e.g. dict with dt_max, dt_min, dt_init_floor, ...
        # Additional Mamba or training fields:
        seq_len=8192,
        weight_tying=True,
        dropout=0.0,
        num_epochs=1,
        global_bsz=524288,
        bsz=1,
        warmup_steps=1907,
        eval_period=50,
        save_period=500,
        max_lr=0.0003,
        min_lr=3e-5,
        max_norm=1.0,
        dilation=1,
        fsdp=True,
        ddp=False,
        mixed_precision=True,
        cpu_offload=False,
        sharding_strategy="full_shard",
        state_dict_type="full",
        auto_wrap_policy="partial",
        backward_prefetch="backward_pre",
        forward_prefetch=False,
        sync_module_states=True,
        use_orig_params=True,
        device_id=None,
        precision=None,  # e.g. dict with param="bfloat16", reduce="bfloat16", buffer="bfloat16"
        fsdp_modules=None,  # e.g. ["MambaBlock"]
        use_activation_checkpointing=True,
        use_attn=False,
        softcap=50.0,
        torch_compile=True,
        # Now accept arbitrary additional kwargs, to remain flexible:
        **kwargs,
    ):
        super().__init__(
            # In HF, these common keys are typically passed to the parent:
            model_type=model_type,
            _name_or_path=_name_or_path,
            architectures=architectures,
            **kwargs,
        )

        self.dim = dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.state_dim = state_dim
        self.num_groups = num_groups
        self.conv_size = conv_size
        self.use_mem_eff_path = use_mem_eff_path
        self.dt_bias = dt_bias
        self.D_has_head_dim = D_has_head_dim
        self.learnable_init_states = learnable_init_states
        self.ssm_chunk_size = ssm_chunk_size
        self.vocab_size = vocab_size
        self.ffn_dim_multiplier = ffn_dim_multiplier
        self.multiple_of = multiple_of
        self.norm_eps = norm_eps
        self.init_use_depth = init_use_depth
        self.init_base_std = init_base_std
        self.init_std_factor = init_std_factor
        self.hidden_act = hidden_act
        self.bias = bias
        self.torch_dtype = torch_dtype
        self.seed = seed

        # Nested init_args (dt_max, dt_min, etc.).
        # Could store it as a dict, or parse out the fields individually:
        self.init_args = init_args or {}

        self.seq_len = seq_len
        self.weight_tying = weight_tying
        self.dropout = dropout
        self.num_epochs = num_epochs
        self.global_bsz = global_bsz
        self.bsz = bsz
        self.warmup_steps = warmup_steps
        self.eval_period = eval_period
        self.save_period = save_period
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.max_norm = max_norm
        self.dilation = dilation
        self.fsdp = fsdp
        self.ddp = ddp
        self.mixed_precision = mixed_precision
        self.cpu_offload = cpu_offload
        self.sharding_strategy = sharding_strategy
        self.state_dict_type = state_dict_type
        self.auto_wrap_policy = auto_wrap_policy
        self.backward_prefetch = backward_prefetch
        self.forward_prefetch = forward_prefetch
        self.sync_module_states = sync_module_states
        self.use_orig_params = use_orig_params
        self.device_id = device_id
        self.precision = precision
        self.fsdp_modules = fsdp_modules
        self.use_activation_checkpointing = use_activation_checkpointing
        self.use_attn = use_attn
        self.softcap = softcap
        self.torch_compile = torch_compile

        # If you want to store any leftover kwargs:
        self.extra_args = kwargs
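# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the config class itself). It relies
# only on the public `transformers` API: AutoConfig.register, save_pretrained,
# and from_pretrained. The "Mamba_500M" directory name and the init_args keys
# shown (dt_min, dt_max, dt_init_floor) are placeholders taken from the
# comments above, not a guaranteed schema.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoConfig

    # Let AutoConfig.from_pretrained(...) resolve model_type "minimamba"
    # to this config class.
    AutoConfig.register("minimamba", MiniMambaConfig)

    # Build a config, overriding a few fields plus the nested init_args block.
    config = MiniMambaConfig(
        dim=1024,
        num_layers=54,
        init_args={"dt_min": 0.001, "dt_max": 0.1, "dt_init_floor": 1e-4},
    )

    # Round-trip through config.json on disk and check the fields survive.
    config.save_pretrained("Mamba_500M")
    reloaded = MiniMambaConfig.from_pretrained("Mamba_500M")
    assert reloaded.num_layers == config.num_layers
    assert reloaded.init_args == config.init_args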