# all-nli-bert-tiny-dense / train_script.py
# Author: tomaarsen (HF staff)
# Commit 2fd38a5 (verified): "Create train_script.py"
from collections import defaultdict
from typing import Dict
import datasets
from datasets import Dataset
from sentence_transformers import (
SentenceTransformer,
SentenceTransformerTrainer,
losses,
evaluation,
SentenceTransformerTrainingArguments
)
from sentence_transformers.models import Transformer, Pooling, Dense, Normalize
def to_triplets(dataset):
    """Convert an NLI dataset into (anchor, positive, negative) triplets.

    Rows are grouped by their ``"premise"``. For each premise one hypothesis
    is kept per ``"label"`` value (a later row with the same premise and label
    overwrites an earlier one). A premise contributes a triplet only when it
    has both a label-0 (entailment) and a label-2 (contradiction) hypothesis.

    Args:
        dataset: Iterable of mappings with ``"premise"``, ``"hypothesis"``
            and ``"label"`` keys.

    Returns:
        A ``Dataset`` with the columns ``"anchor"`` (the premise),
        ``"positive"`` (the entailed hypothesis) and ``"negative"`` (the
        contradicting hypothesis), in order of first appearance of each
        premise.
    """
    # premise -> {label: hypothesis}; insertion order follows the input rows.
    hypotheses_by_premise = defaultdict(dict)
    for row in dataset:
        hypotheses_by_premise[row["premise"]][row["label"]] = row["hypothesis"]

    # Only premises that have both an entailment and a contradiction qualify.
    complete = [
        (premise, by_label)
        for premise, by_label in hypotheses_by_premise.items()
        if 0 in by_label and 2 in by_label
    ]

    columns = {
        "anchor": [premise for premise, _ in complete],
        "positive": [by_label[0] for _, by_label in complete],  # entailment
        "negative": [by_label[2] for _, by_label in complete],  # contradiction
    }
    return Dataset.from_dict(columns)
# SNLI: load every split and turn each one into triplets.
snli_ds = datasets.load_dataset("snli")
snli_ds = datasets.DatasetDict(
    {
        split: to_triplets(snli_ds[split])
        for split in ("train", "validation", "test")
    }
)

# MultiNLI: only the train and matched-validation splits are converted.
multi_nli_ds = datasets.load_dataset("multi_nli")
multi_nli_ds = datasets.DatasetDict(
    {
        split: to_triplets(multi_nli_ds[split])
        for split in ("train", "validation_matched")
    }
)

# AllNLI = SNLI + MultiNLI for train/validation; the test split is SNLI only.
all_nli_train = datasets.concatenate_datasets(
    [snli_ds["train"], multi_nli_ds["train"]]
)  # .select(range(10000))
all_nli_validation = datasets.concatenate_datasets(
    [snli_ds["validation"], multi_nli_ds["validation_matched"]]
)  # .select(range(1000))
all_nli_ds = datasets.DatasetDict(
    {
        "train": all_nli_train,
        "validation": all_nli_validation,
        "test": snli_ds["test"],
    }
)

# STS benchmark splits, used further down to build the similarity evaluators.
stsb_dev = datasets.load_dataset("mteb/stsbenchmark-sts", split="validation")
stsb_test = datasets.load_dataset("mteb/stsbenchmark-sts", split="test")
# Training configuration: a single epoch, with logging, evaluation and
# checkpointing all happening every 100 steps.
training_args = SentenceTransformerTrainingArguments(
    # Where checkpoints are written and how many are kept.
    output_dir="checkpoints",
    save_steps=100,
    save_total_limit=2,
    # Optimisation schedule.
    num_train_epochs=1,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    seed=42,
    # Batch sizes and numeric precision.
    per_device_train_batch_size=256,
    per_device_eval_batch_size=256,
    bf16=True,
    # Logging and periodic evaluation.
    logging_steps=100,
    eval_strategy="steps",
    eval_steps=100,
    # Checkpoints are ranked by this metric; higher is better.
    metric_for_best_model="sts-dev_spearman_cosine",
    greater_is_better=True,
)
# Model pipeline: transformer -> mean pooling -> dense projection -> L2 norm.
transformer = Transformer("prajjwal1/bert-tiny", max_seq_length=384)
pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
# Take the Dense input width from the pooling output rather than hard-coding
# it, so the projection stays in sync with whichever backbone is loaded above.
dense = Dense(pooling.get_sentence_embedding_dimension(), 256)
normalize = Normalize()
model = SentenceTransformer(modules=[transformer, pooling, dense, normalize])

# Ensure all tensors in the model are contiguous
for param in model.parameters():
    param.data = param.data.contiguous()

# In-batch negatives ranking loss over the (anchor, positive, negative) columns.
loss = losses.MultipleNegativesRankingLoss(model)
# loss = losses.MatryoshkaLoss(model, loss, [256, 128, 64, 32, 16, 8])
# STS-B dev evaluator; gold scores are divided by 5 before being passed in.
dev_gold_scores = [gold / 5 for gold in stsb_dev["score"]]
dev_evaluator = evaluation.EmbeddingSimilarityEvaluator(
    sentences1=stsb_dev["sentence1"],
    sentences2=stsb_dev["sentence2"],
    scores=dev_gold_scores,
    main_similarity=evaluation.SimilarityFunction.COSINE,
    name="sts-dev",
)

# Train on the AllNLI triplets, evaluating on both the AllNLI validation split
# and the STS-B dev evaluator.
trainer = SentenceTransformerTrainer(
    model=model,
    args=training_args,
    train_dataset=all_nli_ds["train"],
    eval_dataset=all_nli_ds["validation"],
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()
# Final evaluation on the STS-B test split; gold scores are divided by 5
# before being passed in, exactly as for the dev evaluator.
test_evaluator = evaluation.EmbeddingSimilarityEvaluator(
    stsb_test["sentence1"],
    stsb_test["sentence2"],
    [score / 5 for score in stsb_test["score"]],
    main_similarity=evaluation.SimilarityFunction.COSINE,
    name="sts-test",
)
results = test_evaluator(model)
# Report the test metrics instead of dropping into the debugger, so the script
# can run unattended and still reach the upload below.
print(results)

model.push_to_hub("sentence-transformers-testing/all-nli-bert-tiny-dense", private=True)