# axial_caducues_1200 / configuration_caduceus.py
"""Caduceus config for Hugging Face.
"""
from typing import Optional, Union
from transformers import PretrainedConfig
class CaduceusConfig(PretrainedConfig):
"""Config that extends the original MambaConfig with params relevant to bi-directionality and RC equivariance."""
model_type = "caduceus"
def __init__(
self,
# From original MambaConfig
d_model: int = 2560,
d_intermediate: int = 0,
use_mamba2: bool = False,
n_layer: int = 64,
vocab_size: int = 50277,
ssm_cfg: Optional[dict] = None,
rms_norm: bool = True,
residual_in_fp32: bool = True,
fused_add_norm: bool = True,
pad_vocab_size_multiple: int = 8,
# Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
norm_epsilon: float = 1e-5,
# Used in init_weights
initializer_cfg: Optional[dict] = None,
# Caduceus-specific params
bidirectional: bool = True,
        bidirectional_strategy: Optional[str] = "add",
bidirectional_weight_tie: bool = True,
rcps: bool = False,
complement_map: Optional[dict] = None, # used for RCPSEmbedding / RCPSLMHead
pos_embeddings: Optional[str] = None,
row_first: Optional[bool] = True,
**kwargs,
):
super().__init__(**kwargs)
self.d_model = d_model
self.d_intermediate = d_intermediate
self.use_mamba2 = use_mamba2
self.n_layer = n_layer
self.vocab_size = vocab_size
self.ssm_cfg = ssm_cfg
self.rms_norm = rms_norm
self.residual_in_fp32 = residual_in_fp32
self.fused_add_norm = fused_add_norm
self.pad_vocab_size_multiple = pad_vocab_size_multiple
self.norm_epsilon = norm_epsilon
self.initializer_cfg = initializer_cfg
self.bidirectional = bidirectional
self.bidirectional_strategy = bidirectional_strategy
self.bidirectional_weight_tie = bidirectional_weight_tie
self.rcps = rcps
self.complement_map = complement_map
self.pos_embeddings = pos_embeddings
self.row_first = row_first
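
# A minimal usage sketch: builds a small CaduceusConfig to show how the fields
# above fit together. The hyperparameter values are illustrative placeholders,
# not the settings of any released checkpoint, and `_toy_caduceus_config` is a
# hypothetical helper, not part of the model's API.
def _toy_caduceus_config() -> CaduceusConfig:
    """Return a small CaduceusConfig with illustrative (non-release) values."""
    return CaduceusConfig(
        d_model=256,
        n_layer=4,
        vocab_size=16,
        bidirectional=True,
        bidirectional_strategy="add",
        rcps=True,
        # Hypothetical complement map over token ids (e.g. A<->T, C<->G);
        # consumed by RCPSEmbedding / RCPSLMHead when rcps=True.
        complement_map={0: 3, 1: 2, 2: 1, 3: 0},
    )
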
class AxialCaduceusConfig(PretrainedConfig):
"""Config that extends the original MambaConfig with params relevant to bi-directionality and RC equivariance."""
model_type = "axial_caduceus"
def __init__(
self,
# From original MambaConfig
d_model: int = 2560,
d_intermediate: int = 0,
use_mamba2: bool = False,
n_layer: int = 64,
vocab_size: int = 50277,
ssm_cfg: Optional[dict] = None,
rms_norm: bool = True,
residual_in_fp32: bool = True,
fused_add_norm: bool = True,
pad_vocab_size_multiple: int = 8,
# Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
norm_epsilon: float = 1e-5,
# Used in init_weights
initializer_cfg: Optional[dict] = None,
# Caduceus-specific params
bidirectional: bool = True,
        bidirectional_strategy: Optional[str] = "add",
bidirectional_weight_tie: bool = True,
rcps: bool = False,
complement_map: Optional[dict] = None, # used for RCPSEmbedding / RCPSLMHead
pos_embeddings: Optional[str] = None,
row_first: Optional[bool] = True,
**kwargs,
):
super().__init__(**kwargs)
self.d_model = d_model
self.d_intermediate = d_intermediate
self.use_mamba2 = use_mamba2
self.n_layer = n_layer
self.vocab_size = vocab_size
self.ssm_cfg = ssm_cfg
self.rms_norm = rms_norm
self.residual_in_fp32 = residual_in_fp32
self.fused_add_norm = fused_add_norm
self.pad_vocab_size_multiple = pad_vocab_size_multiple
self.norm_epsilon = norm_epsilon
self.initializer_cfg = initializer_cfg
self.bidirectional = bidirectional
self.bidirectional_strategy = bidirectional_strategy
self.bidirectional_weight_tie = bidirectional_weight_tie
self.rcps = rcps
self.complement_map = complement_map
self.pos_embeddings = pos_embeddings
self.row_first = row_first
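
# A minimal usage sketch for AxialCaduceusConfig, highlighting the axial-specific
# knobs (`pos_embeddings`, `row_first`) on top of the shared Caduceus parameters.
# All values are illustrative; `pos_embeddings` is left at None because its
# accepted string values are model-specific, and the comment on `row_first`
# describes assumed semantics. `_toy_axial_caduceus_config` is only a sketch.
def _toy_axial_caduceus_config() -> AxialCaduceusConfig:
    """Return a small AxialCaduceusConfig with illustrative values."""
    return AxialCaduceusConfig(
        d_model=128,
        n_layer=2,
        vocab_size=16,
        bidirectional=True,
        pos_embeddings=None,  # no positional embeddings in this toy setup
        row_first=True,       # process rows before columns (assumed semantics)
    )
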
class MixedCaduceusConfig(PretrainedConfig):
"""Config that extends the original CaduceusConfig with params relevant to alternating between attention and caducues"""
model_type = "mixed_caduceus"
def __init__(
self,
# From original MambaConfig
d_model: int = 2560,
d_intermediate: int = 0,
use_mamba2: bool = False,
n_layer: int = 64,
vocab_size: int = 50277,
ssm_cfg: Optional[dict] = None,
rms_norm: bool = True,
residual_in_fp32: bool = True,
fused_add_norm: bool = True,
pad_vocab_size_multiple: int = 8,
# Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
norm_epsilon: float = 1e-5,
# Used in init_weights
initializer_cfg: Optional[dict] = None,
# Caduceus-specific params
bidirectional: bool = True,
        bidirectional_strategy: Optional[str] = "add",
bidirectional_weight_tie: bool = True,
rcps: bool = False,
complement_map: Optional[dict] = None, # used for RCPSEmbedding / RCPSLMHead
# attention specific params
attn_d_model: int = 128,
attn_n_heads: int = 16,
attn_attn_dropout: float = 0.1,
attn_block_dropout: float = 0.1,
**kwargs,
):
super().__init__(**kwargs)
self.d_model = d_model
self.d_intermediate = d_intermediate
self.use_mamba2 = use_mamba2
self.n_layer = n_layer
self.vocab_size = vocab_size
self.ssm_cfg = ssm_cfg
self.rms_norm = rms_norm
self.residual_in_fp32 = residual_in_fp32
self.fused_add_norm = fused_add_norm
self.pad_vocab_size_multiple = pad_vocab_size_multiple
self.norm_epsilon = norm_epsilon
self.initializer_cfg = initializer_cfg
self.bidirectional = bidirectional
self.bidirectional_strategy = bidirectional_strategy
self.bidirectional_weight_tie = bidirectional_weight_tie
self.rcps = rcps
self.complement_map = complement_map
self.attn_d_model = attn_d_model
self.attn_n_heads = attn_n_heads
self.attn_attn_dropout = attn_attn_dropout
self.attn_block_dropout = attn_block_dropout
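
# A minimal usage sketch for MixedCaduceusConfig, showing the attention-block
# parameters (`attn_*`) alongside the shared Caduceus/Mamba parameters. All
# values are illustrative placeholders and `_toy_mixed_caduceus_config` is a
# hypothetical helper, not part of the model's API.
def _toy_mixed_caduceus_config() -> MixedCaduceusConfig:
    """Return a small MixedCaduceusConfig with illustrative values."""
    return MixedCaduceusConfig(
        d_model=128,
        n_layer=4,
        vocab_size=16,
        # Attention blocks interleaved with the Caduceus blocks use their own
        # width and head count, plus separate dropout rates for the attention
        # weights and for the block output.
        attn_d_model=128,
        attn_n_heads=8,
        attn_attn_dropout=0.1,
        attn_block_dropout=0.1,
    )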