mamba-130m / configuration_mamba.py
mjschock's picture
Upload config
5585048 verified
raw
history blame
1.3 kB
import math
from typing import Union
from transformers import PretrainedConfig
class MambaConfig(PretrainedConfig):
    """Hyperparameter container registered under ``model_type = "mamba"``.

    Every constructor argument is stored on the instance under its own name.
    Three values are derived while the instance is being built:

    * ``d_inner`` is ``int(expand * d_model)``.
    * ``dt_rank`` passed as ``"auto"`` is replaced by ``ceil(d_model / 16)``.
    * ``vocab_size`` is rounded up to the next multiple of
      ``pad_vocab_size_multiple`` when it is not already one.

    Any extra keyword arguments are forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "mamba"

    def __init__(
        self,
        vocab_size=50277,
        d_state=16,
        d_model=2560,
        d_conv=4,
        expand=2,
        conv_bias=True,
        bias=False,
        n_layer=64,
        dt_rank: Union[int, str] = "auto",
        pad_vocab_size_multiple=8,
        initializer_range=0.02,
        **kwargs,
    ):
        """Store the hyperparameters and compute the derived ones.

        Args:
            vocab_size: Requested vocabulary size; the stored value is padded
                up to a multiple of ``pad_vocab_size_multiple``.
            d_state: Stored unchanged (presumably the SSM state size --
                confirm against the modeling code).
            d_model: Stored unchanged; also used to derive ``d_inner`` and
                the automatic ``dt_rank``.
            d_conv: Stored unchanged (presumably the convolution kernel
                width -- confirm against the modeling code).
            expand: Multiplier applied to ``d_model`` to obtain ``d_inner``.
            conv_bias: Stored unchanged.
            bias: Stored unchanged.
            n_layer: Stored unchanged.
            dt_rank: Either an explicit value, or the string ``"auto"`` to
                use ``ceil(d_model / 16)``.
            pad_vocab_size_multiple: Granularity to which ``vocab_size`` is
                rounded up.
            initializer_range: Stored unchanged.
            **kwargs: Passed through to ``PretrainedConfig.__init__``.
        """
        # Resolve the "auto" sentinel into a concrete rank.
        if dt_rank == "auto":
            dt_rank = math.ceil(d_model / 16)

        # Pad the vocabulary so its size is a whole multiple of the padding unit.
        overhang = vocab_size % pad_vocab_size_multiple
        if overhang != 0:
            vocab_size += pad_vocab_size_multiple - overhang

        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.conv_bias = conv_bias
        self.expand = expand
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.d_conv = d_conv
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = int(expand * d_model)
        self.dt_rank = dt_rank
        self.initializer_range = initializer_range
        self.bias = bias

        super().__init__(**kwargs)