# gte-mean_long-ctx_multi / training_config.yml
# Uploaded by mlconti via huggingface_hub (commit 87bac30, verified)
# Master switch for multi-context training; referenced below via
# `!cfg multi_ctx_training` by both the model and the trainer config.
# Canonical lowercase boolean (yamllint `truthy`); `True` is YAML-1.1 style.
multi_ctx_training: true
# Transformer backbone loaded from a local GTE-ModernBERT checkpoint.
# `()` names the callable to instantiate in this config system; this block
# is referenced elsewhere with `!cfg transformer_layer`.
transformer_layer:
  (): sentence_transformers.models.Transformer
  model_name_or_path: "./models/gte-modernbert-base"
# Mean pooling over token embeddings; 768 matches the backbone's hidden size
# (ModernBERT-base) — TODO confirm against the checkpoint config.
pooling_layer:
  (): sentence_transformers.models.Pooling
  word_embedding_dimension: 768
  pooling_mode: "mean"
# Assemble the SentenceTransformer from the two modules defined above.
base_model:
  (): sentence_transformers.SentenceTransformer
  modules:
    - !cfg transformer_layer
    - !cfg pooling_layer
  model_kwargs:
    # flash_attention_2 requires the flash-attn package and a supported GPU.
    attn_implementation: "flash_attention_2"
    # `!ext` resolves an external Python attribute (here: the torch dtype object).
    torch_dtype: !ext torch.bfloat16
# Top-level experiment definition consumed by the contextual-embeddings trainer.
config:
  (): contextual_embeddings.trainer.contextual_training.ContextualTrainingConfig
  model:
    (): contextual_embeddings.models.long_context_model.LongContextEmbeddingModel
    base_model: !cfg base_model # points to the variable defined above
    multi_ctx_training: !cfg multi_ctx_training
  # NOTE(review): this key intentionally repeats the one inside `model:` at a
  # different nesting level ("passed to both model and trainer"); if both ever
  # end up at the same level it becomes a duplicate key — verify indentation.
  multi_ctx_training: !cfg multi_ctx_training # passed to both model and trainer
  exp_name: "gte-mean_long-ctx_multi"
  n_gpus: 4
  output_dir: "./checkpoints/final_experiments"
  train_dataset:
    (): contextual_embeddings.models.utils.get_long_context_dataset # function returning the dataset
    base_model: !cfg base_model
  eval_dataset:
    mldr:
      (): contextual_embeddings.models.utils.get_chunked_mldr_split
      path: "data_dir/chunked-mldr-big"
      split: "test"
      base_model: !cfg base_model
    squad:
      (): contextual_embeddings.models.utils.get_chunked_mldr_split
      path: "data_dir/squad"
      split: "validation"
      base_model: !cfg base_model
      all_queries: false
    narrative_qa:
      (): contextual_embeddings.models.utils.get_chunked_mldr_split
      path: "data_dir/narrative_qa"
      split: "test"
      base_model: !cfg base_model
      all_queries: false
  run_train: true
  evaluator:
    (): contextual_embeddings.evaluators.NanoBEIRLocalEvaluator.NanoBEIRLocalEvaluator
    dataset_names:
      - "fever"
      - "msmarco"
      - "hotpotqa"
    accuracy_at_k: [1, 5]
    precision_recall_at_k: [1, 5]
    ndcg_at_k: [1, 5]
  training_args:
    (): sentence_transformers.SentenceTransformerTrainingArguments
    # presumably filled in at runtime from `output_dir`/`exp_name` above — TODO confirm
    output_dir: null
    overwrite_output_dir: true
    num_train_epochs: 2
    per_device_train_batch_size: 8
    per_device_eval_batch_size: 8
    fp16: false # Set to false if you get an error that your GPU can't run on FP16
    bf16: true # Set to true if you have a GPU that supports BF16
    # Written with a dot so YAML-1.1 loaders (e.g. PyYAML) resolve it as the
    # float 0.0001; bare `1e-4` is parsed as the STRING "1e-4" by those loaders.
    learning_rate: 1.0e-4
    warmup_steps: 55
    lr_scheduler_type: "cosine"
    eval_strategy: "steps"
    eval_on_start: true
    eval_steps: 100
    logging_steps: 10 # how often to log to W&B
    report_to: "wandb"