import math
from typing import Union

from transformers import PretrainedConfig


class MambaConfig(PretrainedConfig):
    model_type = "mamba"

    def __init__(
        self,
        vocab_size=50277,
        d_state=16,
        d_model=2560,
        d_conv=4,
        expand=2,
        conv_bias=True,
        bias=False,
        n_layer=64,
        norm_epsilon=1e-5,
        dt_rank: Union[int, str] = "auto",
        pad_vocab_size_multiple=8,
        initializer_range=0.02,
        rms_norm: bool = True,
        fused_add_norm: bool = True,
        ssm_cfg=None,
        residual_in_fp32: bool = True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.conv_bias = conv_bias
        self.expand = expand
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.d_conv = d_conv
        self.d_model = d_model
        self.d_state = d_state
        # Inner (expanded) dimension of the Mamba block.
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = dt_rank
        self.initializer_range = initializer_range
        self.bias = bias
        # Avoid a mutable default argument: fall back to an empty dict here.
        self.ssm_cfg = ssm_cfg if ssm_cfg is not None else {}
        self.norm_epsilon = norm_epsilon
        self.rms_norm = rms_norm
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm

        # Resolve the "auto" rank for the dt projection from the model width.
        if self.dt_rank == "auto":
            self.dt_rank = math.ceil(self.d_model / 16)

        # Pad the vocabulary size up to the next multiple for efficiency.
        if self.vocab_size % self.pad_vocab_size_multiple != 0:
            self.vocab_size += (
                self.pad_vocab_size_multiple
                - self.vocab_size % self.pad_vocab_size_multiple
            )

        super().__init__(
            **kwargs,
        )
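
# A minimal usage sketch (not part of the original file): instantiating the
# config with its defaults to show the derived values. vocab_size is padded
# from 50277 up to 50280, the next multiple of pad_vocab_size_multiple (8),
# and dt_rank resolves from "auto" to ceil(2560 / 16) = 160.
if __name__ == "__main__":
    config = MambaConfig()
    print(config.vocab_size)  # 50280 after padding
    print(config.dt_rank)     # 160 = ceil(d_model / 16)
    print(config.d_inner)     # 5120 = expand * d_model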