import math
from typing import Union

from transformers import PretrainedConfig


class MambaConfig(PretrainedConfig):
    """Configuration for a Mamba (selective state-space) language model.

    The defaults roughly correspond to the 2.8B-scale model; the derived
    fields ``d_inner``, ``dt_rank``, and the padded ``vocab_size`` are
    computed in ``__init__``.
    """

    model_type: str = "mamba"

    def __init__(
        self,
        bias: bool = False,
        conv_bias: bool = True,
        d_conv: int = 4,
        d_model: int = 2560,
        d_state: int = 16,
        dt_rank: Union[int, str] = "auto",
        expand: int = 2,
        n_layer: int = 64,
        pad_vocab_size_multiple: int = 8,
        vocab_size: int = 50277,
        **kwargs,
    ):
        self.bias = bias
        self.conv_bias = conv_bias
        self.d_conv = d_conv
        self.d_model = d_model
        self.d_state = d_state
        self.dt_rank = dt_rank
        self.expand = expand
        self.n_layer = n_layer
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.vocab_size = vocab_size

        # Expanded inner dimension used inside each Mamba block.
        self.d_inner = int(self.expand * self.d_model)

        # Rank of the dt (Δ) projection; "auto" uses the d_model / 16 heuristic.
        if self.dt_rank == "auto":
            self.dt_rank = math.ceil(self.d_model / 16)

        # Pad the vocabulary up to a multiple of `pad_vocab_size_multiple`
        # so the embedding and LM-head matrices have hardware-friendly shapes.
        if self.vocab_size % self.pad_vocab_size_multiple != 0:
            self.vocab_size += (
                self.pad_vocab_size_multiple
                - self.vocab_size % self.pad_vocab_size_multiple
            )

        # Alias expected by generic `transformers` tooling.
        self.hidden_size = self.d_model

        super().__init__(**kwargs)
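

# Minimal usage sketch (illustrative small-model sizes, not taken from any
# released checkpoint): shows how the derived fields follow from the rules
# in MambaConfig.__init__ above.
if __name__ == "__main__":
    config = MambaConfig(d_model=768, n_layer=24, vocab_size=50277)
    assert config.d_inner == 1536      # expand * d_model = 2 * 768
    assert config.dt_rank == 48        # ceil(768 / 16)
    assert config.vocab_size == 50280  # 50277 padded up to a multiple of 8
    assert config.hidden_size == 768   # alias of d_model
    print(config)                      # PretrainedConfig subclasses render as JSON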