# Mamba_561M / configuration_minimamba.py
from transformers import PretrainedConfig


class MiniMambaConfig(PretrainedConfig):
    """
    Configuration class for MiniMamba.

    Inherits from Hugging Face's PretrainedConfig so that
        model = MiniMamba.from_pretrained(...)
    loads this config automatically. It includes all fields from the
    provided config.json.
    """

    model_type = "minimamba"

    def __init__(
        self,
        # Standard HF fields:
        model_type="minimamba",
        _name_or_path="Mamba_500M",
        architectures=["MiniMamba"],
        # Key Mamba architecture hyperparameters:
        dim=1024,
        num_layers=54,
        num_heads=32,
        state_dim=128,
        num_groups=1,
        conv_size=4,
        use_mem_eff_path=True,
        dt_bias=True,
        D_has_head_dim=True,
        learnable_init_states=False,
        ssm_chunk_size=256,
        vocab_size=200064,
        ffn_dim_multiplier=2.0,
        multiple_of=256,
        norm_eps=1e-5,
        init_use_depth=False,
        init_base_std=None,
        init_std_factor="disabled",
        hidden_act="silu",
        bias=False,
        # Torch / training:
        torch_dtype="bfloat16",
        seed=1337,
        # The init_config block nested in the JSON:
        init_args=None,  # e.g. dict with dt_max, dt_min, dt_init_floor, ...
        # Additional Mamba or training fields:
        seq_len=8192,
        weight_tying=True,
        dropout=0.0,
        num_epochs=1,
        global_bsz=524288,
        bsz=1,
        warmup_steps=1907,
        eval_period=50,
        save_period=500,
        max_lr=0.0003,
        min_lr=3e-5,
        max_norm=1.0,
        dilation=1,
        fsdp=True,
        ddp=False,
        mixed_precision=True,
        cpu_offload=False,
        sharding_strategy="full_shard",
        state_dict_type="full",
        auto_wrap_policy="partial",
        backward_prefetch="backward_pre",
        forward_prefetch=False,
        sync_module_states=True,
        use_orig_params=True,
        device_id=None,
        precision=None,  # e.g. dict with param="bfloat16", reduce="bfloat16", buffer="bfloat16"
        fsdp_modules=None,  # e.g. ["MambaBlock"]
        use_activation_checkpointing=True,
        use_attn=False,
        softcap=50.0,
        torch_compile=True,
        # Accept arbitrary additional kwargs to remain flexible:
        **kwargs
    ):
        super().__init__(
            # In HF, these common keys are typically passed to the parent:
            model_type=model_type,
            _name_or_path=_name_or_path,
            architectures=architectures,
            **kwargs
        )
        self.dim = dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.state_dim = state_dim
        self.num_groups = num_groups
        self.conv_size = conv_size
        self.use_mem_eff_path = use_mem_eff_path
        self.dt_bias = dt_bias
        self.D_has_head_dim = D_has_head_dim
        self.learnable_init_states = learnable_init_states
        self.ssm_chunk_size = ssm_chunk_size
        self.vocab_size = vocab_size
        self.ffn_dim_multiplier = ffn_dim_multiplier
        self.multiple_of = multiple_of
        self.norm_eps = norm_eps
        self.init_use_depth = init_use_depth
        self.init_base_std = init_base_std
        self.init_std_factor = init_std_factor
        self.hidden_act = hidden_act
        self.bias = bias
        self.torch_dtype = torch_dtype
        self.seed = seed

        # Nested init_args (dt_max, dt_min, etc.) is stored as a dict;
        # its fields could also be parsed out individually.
        self.init_args = init_args or {}

        self.seq_len = seq_len
        self.weight_tying = weight_tying
        self.dropout = dropout
        self.num_epochs = num_epochs
        self.global_bsz = global_bsz
        self.bsz = bsz
        self.warmup_steps = warmup_steps
        self.eval_period = eval_period
        self.save_period = save_period
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.max_norm = max_norm
        self.dilation = dilation
        self.fsdp = fsdp
        self.ddp = ddp
        self.mixed_precision = mixed_precision
        self.cpu_offload = cpu_offload
        self.sharding_strategy = sharding_strategy
        self.state_dict_type = state_dict_type
        self.auto_wrap_policy = auto_wrap_policy
        self.backward_prefetch = backward_prefetch
        self.forward_prefetch = forward_prefetch
        self.sync_module_states = sync_module_states
        self.use_orig_params = use_orig_params
        self.device_id = device_id
        self.precision = precision
        self.fsdp_modules = fsdp_modules
        self.use_activation_checkpointing = use_activation_checkpointing
        self.use_attn = use_attn
        self.softcap = softcap
        self.torch_compile = torch_compile

        # Keep any leftover kwargs accessible as a dict as well:
        self.extra_args = kwargs
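

if __name__ == "__main__":
    # Usage sketch (illustrative, not part of the original checkpoint code):
    # demonstrates the save/load round-trip that inheriting from
    # PretrainedConfig enables, as described in the class docstring.
    # The directory name and the init_args values below are assumptions
    # made for this example, not values shipped with the model.
    cfg = MiniMambaConfig(
        dim=1024,
        num_layers=54,
        init_args={"dt_min": 1e-3, "dt_max": 1e-1, "dt_init_floor": 1e-4},
    )
    cfg.save_pretrained("./minimamba_config_demo")

    # Reload and confirm that both flat and nested fields round-trip
    # through the serialized config.json.
    reloaded = MiniMambaConfig.from_pretrained("./minimamba_config_demo")
    assert reloaded.dim == 1024
    assert reloaded.init_args["dt_max"] == 1e-1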