from collections import defaultdict

import datasets
from datasets import Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    evaluation,
    losses,
)
from sentence_transformers.models import Dense, Normalize, Pooling, Transformer


def to_triplets(dataset):
    """Convert an NLI split into (anchor, positive, negative) triplets.

    NLI labels: 0 = entailment, 1 = neutral, 2 = contradiction. For each
    premise, keep one hypothesis per label and emit a triplet whenever the
    premise has both an entailment (positive) and a contradiction (negative).
    """
    premises = defaultdict(dict)
    for sample in dataset:
        premises[sample["premise"]][sample["label"]] = sample["hypothesis"]

    queries = []
    positives = []
    negatives = []
    for premise, sentences in premises.items():
        if 0 in sentences and 2 in sentences:
            queries.append(premise)
            positives.append(sentences[0])
            negatives.append(sentences[2])

    return Dataset.from_dict({
        "anchor": queries,
        "positive": positives,
        "negative": negatives,
    })
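

# Quick sanity check of to_triplets on a toy in-memory dataset; the sentences
# below are illustrative, not taken from SNLI.
_toy = Dataset.from_dict({
    "premise": ["A man plays guitar."] * 3,
    "hypothesis": ["A person makes music.", "A man holds an instrument.", "The man sleeps."],
    "label": [0, 1, 2],
})
assert to_triplets(_toy)[0] == {
    "anchor": "A man plays guitar.",
    "positive": "A person makes music.",
    "negative": "The man sleeps.",
}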


# SNLI and MultiNLI together form the AllNLI mixture commonly used for
# training sentence embedding models.
snli_ds = datasets.load_dataset("snli")
snli_ds = datasets.DatasetDict({
    "train": to_triplets(snli_ds["train"]),
    "validation": to_triplets(snli_ds["validation"]),
    "test": to_triplets(snli_ds["test"]),
})

multi_nli_ds = datasets.load_dataset("multi_nli")
multi_nli_ds = datasets.DatasetDict({
    "train": to_triplets(multi_nli_ds["train"]),
    "validation_matched": to_triplets(multi_nli_ds["validation_matched"]),
})

# MultiNLI has no public test labels, so SNLI's test split stands in as the
# combined test set.
all_nli_ds = datasets.DatasetDict({
    "train": datasets.concatenate_datasets([snli_ds["train"], multi_nli_ds["train"]]),
    "validation": datasets.concatenate_datasets([snli_ds["validation"], multi_nli_ds["validation_matched"]]),
    "test": snli_ds["test"],
})
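
# Optional: inspect the combined splits and their sizes before training.
print(all_nli_ds)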

# STS Benchmark pairs are scored for similarity on a 0-5 scale; the evaluators
# below rescale the gold scores to [0, 1].
stsb_dev = datasets.load_dataset("mteb/stsbenchmark-sts", split="validation")
stsb_test = datasets.load_dataset("mteb/stsbenchmark-sts", split="test")

training_args = SentenceTransformerTrainingArguments(
    output_dir="checkpoints",
    num_train_epochs=1,
    seed=42,
    per_device_train_batch_size=256,
    per_device_eval_batch_size=256,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    bf16=True,  # assumes bf16-capable hardware; switch to fp16=True otherwise
    logging_steps=100,
    eval_strategy="steps",
    eval_steps=100,
    save_steps=100,
    save_total_limit=2,
    # metric_for_best_model is only honored when load_best_model_at_end=True.
    # The metric name comes from the "sts-dev" evaluator defined below.
    load_best_model_at_end=True,
    metric_for_best_model="sts-dev_spearman_cosine",
    greater_is_better=True,
)

# Build the model as an explicit module pipeline:
# Transformer -> mean pooling -> dense projection -> L2 normalization.
transformer = Transformer("prajjwal1/bert-tiny", max_seq_length=384)
pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
# bert-tiny pools to 128 dimensions; project up to 256.
dense = Dense(in_features=pooling.get_sentence_embedding_dimension(), out_features=256)
normalize = Normalize()
model = SentenceTransformer(modules=[transformer, pooling, dense, normalize])
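
# Sanity check: the dense projection should make embeddings 256-dimensional.
assert model.get_sentence_embedding_dimension() == 256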

# Safetensors serialization rejects non-contiguous tensors, so make every
# parameter contiguous before the trainer saves a checkpoint.
for param in model.parameters():
    param.data = param.data.contiguous()

# MultipleNegativesRankingLoss treats each (anchor, positive) pair as a match
# and uses the explicit negative plus all other in-batch candidates as negatives.
loss = losses.MultipleNegativesRankingLoss(model)

dev_evaluator = evaluation.EmbeddingSimilarityEvaluator(
    stsb_dev["sentence1"],
    stsb_dev["sentence2"],
    [score / 5 for score in stsb_dev["score"]],  # rescale 0-5 gold scores to [0, 1]
    main_similarity=evaluation.SimilarityFunction.COSINE,
    name="sts-dev",
)

trainer = SentenceTransformerTrainer(
    model=model,
    evaluator=dev_evaluator,
    args=training_args,
    train_dataset=all_nli_ds["train"],
    eval_dataset=all_nli_ds["validation"],
    loss=loss,
)
trainer.train()

test_evaluator = evaluation.EmbeddingSimilarityEvaluator(
    stsb_test["sentence1"],
    stsb_test["sentence2"],
    [score / 5 for score in stsb_test["score"]],  # rescale 0-5 gold scores to [0, 1]
    main_similarity=evaluation.SimilarityFunction.COSINE,
    name="sts-test",
)
results = test_evaluator(model)
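# In recent sentence-transformers releases the evaluator returns a dict of
# metrics keyed by name, e.g. "sts-test_spearman_cosine".
print(results)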

model.push_to_hub("sentence-transformers-testing/all-nli-bert-tiny-dense", private=True)
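
# A minimal usage sketch once the model is on the Hub; uncomment to run. It
# assumes access to the private repository pushed above.
# loaded = SentenceTransformer("sentence-transformers-testing/all-nli-bert-tiny-dense")
# embeddings = loaded.encode(["A man is playing guitar.", "Someone makes music."])
# print(loaded.similarity(embeddings[:1], embeddings[1:]))  # cosine similarity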