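"""Train a tiny SentenceTransformer (prajjwal1/bert-tiny + mean pooling + Dense + Normalize)
on AllNLI triplets built from SNLI and MultiNLI, evaluate on STS-Benchmark during and after
training, and push the resulting model to the Hugging Face Hub."""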
from collections import defaultdict
from typing import Dict
import datasets
from datasets import Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    losses,
    evaluation,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.models import Transformer, Pooling, Dense, Normalize
def to_triplets(dataset):
    """Group an NLI split by premise and build (anchor, positive, negative) triplets.

    NLI label convention: 0 = entailment, 1 = neutral, 2 = contradiction.
    Rows without a gold label (label == -1 in SNLI) are effectively ignored,
    since only labels 0 and 2 are used.
    """
    premises = defaultdict(dict)
    for sample in dataset:
        premises[sample["premise"]][sample["label"]] = sample["hypothesis"]
    queries = []
    positives = []
    negatives = []
    for premise, sentences in premises.items():
        if 0 in sentences and 2 in sentences:
            queries.append(premise)
            positives.append(sentences[0])  # <- entailment
            negatives.append(sentences[2])  # <- contradiction
    return Dataset.from_dict({
        "anchor": queries,
        "positive": positives,
        "negative": negatives,
    })
snli_ds = datasets.load_dataset("snli")
snli_ds = datasets.DatasetDict({
    "train": to_triplets(snli_ds["train"]),
    "validation": to_triplets(snli_ds["validation"]),
    "test": to_triplets(snli_ds["test"]),
})
multi_nli_ds = datasets.load_dataset("multi_nli")
multi_nli_ds = datasets.DatasetDict({
    "train": to_triplets(multi_nli_ds["train"]),
    "validation_matched": to_triplets(multi_nli_ds["validation_matched"]),
})
all_nli_ds = datasets.DatasetDict({
    "train": datasets.concatenate_datasets([snli_ds["train"], multi_nli_ds["train"]]),  # .select(range(10000)),
    "validation": datasets.concatenate_datasets([snli_ds["validation"], multi_nli_ds["validation_matched"]]),  # .select(range(1000)),
    "test": snli_ds["test"],
})
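# Optional sanity check of the assembled splits (sizes depend on the downloaded dataset versions):
# print({split: len(ds) for split, ds in all_nli_ds.items()})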
stsb_dev = datasets.load_dataset("mteb/stsbenchmark-sts", split="validation")
stsb_test = datasets.load_dataset("mteb/stsbenchmark-sts", split="test")
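# STS-Benchmark scores are on a 0-5 scale; they are divided by 5 below so the
# evaluators receive similarity labels in [0, 1].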
training_args = SentenceTransformerTrainingArguments(
    output_dir="checkpoints",
    num_train_epochs=1,
    seed=42,
    per_device_train_batch_size=256,
    per_device_eval_batch_size=256,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    bf16=True,
    logging_steps=100,
    eval_strategy="steps",
    eval_steps=100,
    save_steps=100,
    save_total_limit=2,
    metric_for_best_model="sts-dev_spearman_cosine",
    greater_is_better=True,
)
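# Note: bf16=True assumes hardware with bfloat16 support (e.g. NVIDIA Ampere or newer);
# switch to fp16=True or drop it otherwise. "sts-dev_spearman_cosine" refers to the
# "sts-dev" evaluator defined below (Spearman correlation of cosine similarities).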
transformer = Transformer("prajjwal1/bert-tiny", max_seq_length=384)  # 2-layer BERT, hidden size 128
pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
dense = Dense(128, 256)  # project the 128-dim pooled embedding up to 256 dimensions
normalize = Normalize()  # L2-normalize embeddings so cosine similarity equals the dot product
model = SentenceTransformer(modules=[transformer, pooling, dense, normalize])
# Ensure all tensors in the model are contiguous
for param in model.parameters():
param.data = param.data.contiguous()
loss = losses.MultipleNegativesRankingLoss(model)
# loss = losses.MatryoshkaLoss(model, loss, [256, 128, 64, 32, 16, 8])
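# MultipleNegativesRankingLoss scores each anchor against its own positive plus the other
# in-batch positives and negatives, so the large batch size (256) effectively provides many
# negatives per anchor. The commented-out MatryoshkaLoss wrapper would additionally train
# the embeddings truncated to 256, 128, ..., 8 dimensions.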
dev_evaluator = evaluation.EmbeddingSimilarityEvaluator(
    stsb_dev["sentence1"],
    stsb_dev["sentence2"],
    [score / 5 for score in stsb_dev["score"]],
    main_similarity=evaluation.SimilarityFunction.COSINE,
    name="sts-dev",
)
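# The evaluator computes Pearson/Spearman correlations between the model's similarity scores
# and the gold STS labels; passed to the trainer below, it runs every eval_steps alongside
# the loss on the validation triplets.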
trainer = SentenceTransformerTrainer(
    model=model,
    evaluator=dev_evaluator,
    args=training_args,
    train_dataset=all_nli_ds["train"],
    eval_dataset=all_nli_ds["validation"],
    loss=loss,
)
trainer.train()
test_evaluator = evaluation.EmbeddingSimilarityEvaluator(
    stsb_test["sentence1"],
    stsb_test["sentence2"],
    [score / 5 for score in stsb_test["score"]],
    main_similarity=evaluation.SimilarityFunction.COSINE,
    name="sts-test",
)
results = test_evaluator(model)
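# `results` is a dict of metrics in recent sentence-transformers (v3+) releases; the exact
# key format may vary by version, e.g. (hypothetical key name):
# print(results["sts-test_spearman_cosine"])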
breakpoint()  # drop into the debugger to inspect `results` before pushing to the Hub
model.push_to_hub("sentence-transformers-testing/all-nli-bert-tiny-dense", private=True)