from typing import Any, Dict, Mapping, Optional, Union

from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.m2m_100.configuration_m2m_100 import M2M100Config

NLLBLLM2VEC_TYPE = "nllb-llm2vec"

# Default keyword arguments used to build the nested M2M100Config.
# NOTE: the "_attn_implementation" entry is always overridden by
# NLLBLLM2VecConfig.__init__ after construction.
DEFAULT_M2M100_CONFIG = {
    "activation_dropout": 0.0,
    "activation_function": "relu",
    "architectures": ["M2M100Encoder"],
    "attention_dropout": 0.1,
    "bos_token_id": 0,
    "d_model": 1024,
    "decoder_attention_heads": 16,
    "decoder_ffn_dim": 4096,
    "decoder_layerdrop": 0,
    "decoder_layers": 12,
    "decoder_start_token_id": 2,
    "dropout": 0.1,
    "encoder_attention_heads": 16,
    "encoder_ffn_dim": 4096,
    "encoder_layerdrop": 0,
    "encoder_layers": 12,
    "eos_token_id": 2,
    "init_std": 0.02,
    "is_encoder_decoder": True,
    "max_position_embeddings": 1024,
    "model_type": "m2m_100",
    "num_hidden_layers": 12,
    "pad_token_id": 1,
    "scale_embedding": True,
    "torch_dtype": "float32",
    "transformers_version": "4.21.0.dev0",
    "use_cache": True,
    "vocab_size": 256206,
    "tokenizer_class": "NllbTokenizer",
    "max_length": 200,
    "_attn_implementation": "flash_attention_2",
}

# Default keyword arguments used to build the nested LlamaConfig.
# NOTE: the "_attn_implementation" entry is always overridden by
# NLLBLLM2VecConfig.__init__ after construction.
DEFAULT_LLAMA_CONFIG = {
    "attention_bias": False,
    "attention_dropout": 0,
    "bos_token_id": 128000,
    "eos_token_id": 128001,
    "hidden_act": "silu",
    "hidden_size": 4096,
    "initializer_range": 0.02,
    "intermediate_size": 14336,
    "max_position_embeddings": 8192,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 32,
    "num_key_value_heads": 8,
    "pretraining_tp": 1,
    "rms_norm_eps": 0.00001,
    "rope_scaling": None,
    "rope_theta": 500000,
    "tie_word_embeddings": False,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.40.0.dev0",
    "use_cache": False,
    "vocab_size": 128256,
    "_attn_implementation": "flash_attention_2",
}


def _to_config_kwargs(
    config: Optional[Union[Mapping[str, Any], PretrainedConfig]],
    default: Mapping[str, Any],
) -> Dict[str, Any]:
    """Normalize a sub-config argument into a fresh kwargs dict.

    Args:
        config: ``None`` (use ``default``), an existing ``PretrainedConfig``
            instance (converted with ``to_dict()``), or a mapping of kwargs.
        default: Kwargs to use when ``config`` is ``None``.

    Returns:
        A new dict, so the caller's mapping and the module-level defaults
        are never shared with the constructed config.
    """
    if config is None:
        config = default
    elif isinstance(config, PretrainedConfig):
        config = config.to_dict()
    return dict(config)


class NLLBLLM2VecConfig(PretrainedConfig):
    """Composite config holding an ``M2M100Config`` and a ``LlamaConfig``.

    Args:
        nllb_config: Kwargs (or an ``M2M100Config`` instance) for the nested
            ``M2M100Config``. Defaults to ``DEFAULT_M2M100_CONFIG``.
        llm2vec_config: Kwargs (or a ``LlamaConfig`` instance) for the nested
            ``LlamaConfig``. Defaults to ``DEFAULT_LLAMA_CONFIG``.
        _attn_implementation: Attention implementation applied to this config
            and to both nested configs. It is superseded by the public
            ``attn_implementation`` kwarg when that is provided and not ``None``.
        initializer_range: Value for ``self.initializer_range``. When ``None``,
            the nested ``LlamaConfig.initializer_range`` is used.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = NLLBLLM2VEC_TYPE
    is_composition = False

    def __init__(
        self,
        nllb_config: Optional[Union[Dict, PretrainedConfig]] = None,
        llm2vec_config: Optional[Union[Dict, PretrainedConfig]] = None,
        _attn_implementation="sdpa",
        initializer_range: Optional[float] = None,
        **kwargs,
    ):
        # transformers forwards the public `attn_implementation` kwarg through
        # **kwargs (e.g. from `from_pretrained(..., attn_implementation=...)`).
        # Previously it was silently overwritten by the `_attn_implementation`
        # default, so honor it when it is explicitly set.
        requested = kwargs.get("attn_implementation")
        attn_implementation = (
            requested if requested is not None else _attn_implementation
        )

        super().__init__(**kwargs)
        self._attn_implementation = attn_implementation

        self.nllb_config = M2M100Config(
            **_to_config_kwargs(nllb_config, DEFAULT_M2M100_CONFIG)
        )
        self.nllb_config._attn_implementation = attn_implementation

        self.llm2vec_config = LlamaConfig(
            **_to_config_kwargs(llm2vec_config, DEFAULT_LLAMA_CONFIG)
        )
        self.llm2vec_config._attn_implementation = attn_implementation

        # Inherit the Llama initializer std unless explicitly overridden.
        # NOTE(review): the original ended with a no-op read of
        # `self.llm2vec_config.initializer_range`; if it was meant to sync
        # the nested config with an override, that was never happening —
        # confirm the intended behavior.
        if initializer_range is None:
            self.initializer_range = self.llm2vec_config.initializer_range
        else:
            self.initializer_range = initializer_range


AutoConfig.register(NLLBLLM2VEC_TYPE, NLLBLLM2VecConfig)
NLLBLLM2VecConfig.register_for_auto_class()