from transformers import PretrainedConfig


class MiniMambaConfig(PretrainedConfig):
    """
    Config class for MiniMamba, covering the core architecture fields as well
    as the extended training/FSDP fields from the provided config.json.

    Inherits from Hugging Face's PretrainedConfig so that

        model = MiniMamba.from_pretrained(...)

    loads this config automatically.
    """

    model_type = "minimamba"

    def __init__(
        self,
        # Model identity (top-level keys from config.json)
        model_type="minimamba",
        _name_or_path="Mamba_500M",
        architectures=["MiniMamba"],
        # Core architecture
        dim=1024,
        num_layers=54,
        num_heads=32,
        state_dim=128,
        num_groups=1,
        conv_size=4,
        use_mem_eff_path=True,
        dt_bias=True,
        D_has_head_dim=True,
        learnable_init_states=False,
        ssm_chunk_size=256,
        vocab_size=200064,
        ffn_dim_multiplier=2.0,
        multiple_of=256,
        norm_eps=1e-5,
        init_use_depth=False,
        init_base_std=None,
        init_std_factor="disabled",
        hidden_act="silu",
        bias=False,
        # Numerics and reproducibility
        torch_dtype="bfloat16",
        seed=1337,
        # Extra initialization arguments
        init_args=None,
        # Training settings
        seq_len=8192,
        weight_tying=True,
        dropout=0.0,
        num_epochs=1,
        global_bsz=524288,
        bsz=1,
        warmup_steps=1907,
        eval_period=50,
        save_period=500,
        max_lr=0.0003,
        min_lr=3e-5,
        max_norm=1.0,
        dilation=1,
        # Distributed training (FSDP/DDP)
        fsdp=True,
        ddp=False,
        mixed_precision=True,
        cpu_offload=False,
        sharding_strategy="full_shard",
        state_dict_type="full",
        auto_wrap_policy="partial",
        backward_prefetch="backward_pre",
        forward_prefetch=False,
        sync_module_states=True,
        use_orig_params=True,
        device_id=None,
        precision=None,
        fsdp_modules=None,
        use_activation_checkpointing=True,
        use_attn=False,
        softcap=50.0,
        torch_compile=True,
        # Anything else from config.json
        **kwargs
    ):
        super().__init__(
            model_type=model_type,
            _name_or_path=_name_or_path,
            architectures=architectures,
            **kwargs
        )

        self.dim = dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.state_dim = state_dim
        self.num_groups = num_groups
        self.conv_size = conv_size
        self.use_mem_eff_path = use_mem_eff_path
        self.dt_bias = dt_bias
        self.D_has_head_dim = D_has_head_dim
        self.learnable_init_states = learnable_init_states
        self.ssm_chunk_size = ssm_chunk_size
        self.vocab_size = vocab_size
        self.ffn_dim_multiplier = ffn_dim_multiplier
        self.multiple_of = multiple_of
        self.norm_eps = norm_eps
        self.init_use_depth = init_use_depth
        self.init_base_std = init_base_std
        self.init_std_factor = init_std_factor
        self.hidden_act = hidden_act
        self.bias = bias

        self.torch_dtype = torch_dtype
        self.seed = seed

        self.init_args = init_args or {}

        self.seq_len = seq_len
        self.weight_tying = weight_tying
        self.dropout = dropout
        self.num_epochs = num_epochs
        self.global_bsz = global_bsz
        self.bsz = bsz
        self.warmup_steps = warmup_steps
        self.eval_period = eval_period
        self.save_period = save_period
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.max_norm = max_norm
        self.dilation = dilation
        self.fsdp = fsdp
        self.ddp = ddp
        self.mixed_precision = mixed_precision
        self.cpu_offload = cpu_offload
        self.sharding_strategy = sharding_strategy
        self.state_dict_type = state_dict_type
        self.auto_wrap_policy = auto_wrap_policy
        self.backward_prefetch = backward_prefetch
        self.forward_prefetch = forward_prefetch
        self.sync_module_states = sync_module_states
        self.use_orig_params = use_orig_params
        self.device_id = device_id
        self.precision = precision
        self.fsdp_modules = fsdp_modules
        self.use_activation_checkpointing = use_activation_checkpointing
        self.use_attn = use_attn
        self.softcap = softcap
        self.torch_compile = torch_compile

        self.extra_args = kwargs
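

# Usage sketch (illustrative; not part of the original module). It shows the
# config round-tripping through PretrainedConfig's standard save/load
# machinery. The AutoConfig registration and the temporary directory are
# assumptions made for this example; the MiniMamba model class itself is
# defined elsewhere in the repo.
if __name__ == "__main__":
    import tempfile

    from transformers import AutoConfig

    # Optional: lets AutoConfig.from_pretrained resolve model_type="minimamba"
    # to this class when loading a local config.json.
    AutoConfig.register("minimamba", MiniMambaConfig)

    config = MiniMambaConfig()

    with tempfile.TemporaryDirectory() as tmp_dir:
        # save_pretrained writes config.json (including the extended
        # training/FSDP fields) into tmp_dir.
        config.save_pretrained(tmp_dir)

        # from_pretrained reads it back; spot-check that fields survived.
        reloaded = MiniMambaConfig.from_pretrained(tmp_dir)
        assert reloaded.num_layers == config.num_layers
        assert reloaded.seq_len == config.seq_len
        print(reloaded.model_type, reloaded.dim, reloaded.num_layers)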