# bert-bi-encoder/configs/fine-tune.yaml
# lightning.pytorch==2.3.3
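# Fine-tuning config for a BERT-based bi-encoder with Lightning IR.
# Usage sketch (an assumption about the CLI invocation, not stated in this file):
#   lightning-ir fit --config fine-tune.yaml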
seed_everything: 0
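# Trainer: bf16 mixed-precision training, capped at 50,000 optimizer steps
# rather than a fixed number of epochs.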
trainer:
  precision: bf16-mixed
  max_steps: 50000
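# Training data: a RunDataset reads a run file of query-document pairs with
# teacher scores (here the Rank-DistiLLM set-encoder run for MS MARCO passage).
# Per query, 8 documents are sampled from the top 100 of the run; "log_random"
# is read here as rank-biased random sampling that favors top-ranked documents
# (an interpretation of the strategy name, not verified against the source).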
data:
  class_path: lightning_ir.LightningIRDataModule
  init_args:
    num_workers: 1
    train_batch_size: 64
    shuffle_train: true
    train_dataset:
      class_path: lightning_ir.RunDataset
      init_args:
        run_path_or_id: msmarco-passage/train/rank-distillm/set-encoder
        depth: 100
        sample_size: 8
        sampling_strategy: log_random
        targets: score
        normalize_targets: false
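# Model: a single-vector bi-encoder on top of bert-base-uncased. Mean pooling
# over the 768-dim hidden states yields one embedding per query and document
# (no projection layer, no expansion, no sparsification), scored by dot
# product without normalization. Queries are truncated to 32 tokens and
# documents to 256.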
model:
  class_path: lightning_ir.BiEncoderModule
  init_args:
    model_name_or_path: bert-base-uncased
    config:
      class_path: lightning_ir.BiEncoderConfig
      init_args:
        similarity_function: dot
        query_expansion: false
        attend_to_query_expanded_tokens: false
        query_pooling_strategy: mean
        query_mask_scoring_tokens: null
        query_aggregation_function: sum
        doc_expansion: false
        attend_to_doc_expanded_tokens: false
        doc_pooling_strategy: mean
        doc_mask_scoring_tokens: null
        normalize: false
        sparsification: null
        add_marker_tokens: false
        embedding_dim: 768
        projection: null
        query_length: 32
        doc_length: 256
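    # The three losses below are combined during training: margin-MSE and KL
    # divergence distill the teacher scores from the run file, while the
    # in-batch cross-entropy adds a contrastive signal. With first/first
    # sampling, each query's first sampled document acts as the positive and
    # up to 8 in-batch documents as negatives (a reading of the sampling
    # settings, not verified against the lightning_ir source).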
    loss_functions:
      - class_path: lightning_ir.SupervisedMarginMSE
      - class_path: lightning_ir.KLDivergence
      - class_path: lightning_ir.InBatchCrossEntropy
        init_args:
          pos_sampling_technique: first
          neg_sampling_technique: first
          max_num_neg_samples: 8
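# Optimizer: AdamW at a 2e-5 learning rate, the usual range for fine-tuning
# BERT-sized encoders.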
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 2.0e-05
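# Schedule: linear warmup over the first 5,000 steps (10% of training), then
# linear decay; final_value is presumably the fraction of the peak learning
# rate reached at step 50,000 (i.e., 2% of 2e-5), with no delay before warmup.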
lr_scheduler:
  class_path: lightning_ir.LinearLRSchedulerWithLinearWarmup
  init_args:
    num_warmup_steps: 5000
    final_value: 0.02
    num_delay_steps: 0