|
import math |
|
from typing import Union |
|
|
|
from transformers import PretrainedConfig |
|
|
|
|
|
class MambaConfig(PretrainedConfig):
    """Configuration container for a Mamba model.

    Stores the constructor arguments as attributes and derives two values at
    construction time:

    * ``d_inner`` is ``int(expand * d_model)``.
    * ``dt_rank``, when passed as ``"auto"``, becomes ``ceil(d_model / 16)``.

    ``vocab_size`` is rounded *up* to the next multiple of
    ``pad_vocab_size_multiple``, so the stored value may be larger than the
    value passed in.

    Args:
        vocab_size: Requested vocabulary size before padding.
        d_state: Stored unchanged. Presumably the SSM state size; it is
            consumed by the model, not in this class.
        d_model: Model width. Also used to derive ``d_inner`` and the
            ``"auto"`` value of ``dt_rank``.
        d_conv: Stored unchanged. Presumably the convolution kernel width.
        expand: Multiplier applied to ``d_model`` to obtain ``d_inner``.
        conv_bias: Stored unchanged. Presumably toggles a bias on the
            convolution.
        bias: Stored unchanged. Presumably toggles biases on the linear
            projections.
        n_layer: Stored unchanged. Presumably the number of layers.
        dt_rank: An int, or the string ``"auto"``.
        pad_vocab_size_multiple: Positive integer. ``vocab_size`` is rounded
            up to a multiple of it.
        initializer_range: Stored unchanged. Presumably the std used for
            weight initialization.
        **kwargs: Forwarded verbatim to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``dt_rank`` is a string other than ``"auto"``, or if
            ``pad_vocab_size_multiple`` is less than 1.
    """

    model_type = "mamba"

    def __init__(
        self,
        vocab_size=50277,
        d_state=16,
        d_model=2560,
        d_conv=4,
        expand=2,
        conv_bias=True,
        bias=False,
        n_layer=64,
        dt_rank: Union[int, str] = "auto",
        pad_vocab_size_multiple=8,
        initializer_range=0.02,
        **kwargs,
    ):
        # Validate up front. Without this, a zero multiple surfaces as an
        # opaque ZeroDivisionError from `%`, and a negative one silently
        # yields a nonsensical vocab size.
        if pad_vocab_size_multiple < 1:
            raise ValueError(
                "pad_vocab_size_multiple must be a positive integer, "
                f"got {pad_vocab_size_multiple!r}"
            )
        # Only the "auto" sentinel is resolved below. Any other string would
        # otherwise be stored as-is and fail much later, far from its cause.
        if isinstance(dt_rank, str) and dt_rank != "auto":
            raise ValueError(
                f"dt_rank must be an int or 'auto', got {dt_rank!r}"
            )

        # Round the vocabulary size up to the next multiple of
        # `pad_vocab_size_multiple` (a no-op if it is already a multiple).
        remainder = vocab_size % pad_vocab_size_multiple
        if remainder:
            vocab_size += pad_vocab_size_multiple - remainder

        if dt_rank == "auto":
            dt_rank = math.ceil(d_model / 16)

        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.conv_bias = conv_bias
        self.expand = expand
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.d_conv = d_conv
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = int(expand * d_model)
        self.dt_rank = dt_rank
        self.initializer_range = initializer_range
        self.bias = bias

        super().__init__(**kwargs)
|
|