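# Train a small sentence-embedding model on AllNLI (SNLI + MultiNLI) triplets
# with sentence-transformers and evaluate it on the STS Benchmark.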
from collections import defaultdict
import datasets
from datasets import Dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    losses,
    evaluation,
    SentenceTransformerTrainingArguments
)
from sentence_transformers.models import Transformer, Pooling, Dense, Normalize

def to_triplets(dataset):
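    """Group NLI pairs by premise and keep premises that have both an
    entailment (label 0) and a contradiction (label 2) hypothesis,
    yielding (anchor, positive, negative) triplets."""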
    premises = defaultdict(dict)
    for sample in dataset:
        premises[sample["premise"]][sample["label"]] = sample["hypothesis"]
    queries = []
    positives = []
    negatives = []
    for premise, sentences in premises.items():
        if 0 in sentences and 2 in sentences:
            queries.append(premise)
            positives.append(sentences[0]) # <- entailment
            negatives.append(sentences[2]) # <- contradiction
    return Dataset.from_dict({
        "anchor": queries,
        "positive": positives,
        "negative": negatives,
    })

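# Load SNLI and MultiNLI and convert each split into triplets.
# Labels: 0 = entailment, 1 = neutral, 2 = contradiction (-1 = no gold label).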
snli_ds = datasets.load_dataset("snli")
snli_ds = datasets.DatasetDict({
    "train": to_triplets(snli_ds["train"]),
    "validation": to_triplets(snli_ds["validation"]),
    "test": to_triplets(snli_ds["test"]),
})
multi_nli_ds = datasets.load_dataset("multi_nli")
multi_nli_ds = datasets.DatasetDict({
    "train": to_triplets(multi_nli_ds["train"]),
    "validation_matched": to_triplets(multi_nli_ds["validation_matched"]),
})

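# Merge SNLI and MultiNLI into a single "AllNLI" dataset; the commented-out
# .select(...) calls can be re-enabled to train on a small subset for testing.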
all_nli_ds = datasets.DatasetDict({
    "train": datasets.concatenate_datasets([snli_ds["train"], multi_nli_ds["train"]]),#.select(range(10000)),
    "validation": datasets.concatenate_datasets([snli_ds["validation"], multi_nli_ds["validation_matched"]]),#.select(range(1000)),
    "test": snli_ds["test"]
})

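# STS Benchmark pairs come with human similarity scores from 0 to 5.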
stsb_dev = datasets.load_dataset("mteb/stsbenchmark-sts", split="validation")
stsb_test = datasets.load_dataset("mteb/stsbenchmark-sts", split="test")

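# Training configuration; bf16=True assumes a GPU with bfloat16 support.
# The best checkpoint is tracked by Spearman correlation on the dev evaluator.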
training_args = SentenceTransformerTrainingArguments(
    output_dir="checkpoints",
    num_train_epochs=1,
    seed=42,
    per_device_train_batch_size=256,
    per_device_eval_batch_size=256,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    bf16=True,
    logging_steps=100,
    eval_strategy="steps",
    eval_steps=100,
    save_steps=100,
    save_total_limit=2,
    metric_for_best_model="sts-dev_spearman_cosine",
    greater_is_better=True,
)

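# Assemble the model from scratch: a tiny BERT encoder, mean pooling, a dense
# layer projecting the 128-dim pooled output to 256 dims, and L2 normalization
# so cosine similarity reduces to a dot product.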
transformer = Transformer("prajjwal1/bert-tiny", max_seq_length=384)
pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
dense = Dense(in_features=pooling.get_sentence_embedding_dimension(), out_features=256)
normalize = Normalize()
model = SentenceTransformer(modules=[transformer, pooling, dense, normalize])
# Make all model tensors contiguous so they can be serialized (safetensors
# refuses to save non-contiguous tensors when checkpointing or pushing).
for param in model.parameters():
    param.data = param.data.contiguous()

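# MultipleNegativesRankingLoss treats every other example in the batch as a
# negative in addition to the explicit hard negative, so larger batches give
# a stronger training signal. Wrapping it in MatryoshkaLoss (commented out
# below) would additionally train truncated embeddings at the listed dims.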
loss = losses.MultipleNegativesRankingLoss(model)
# loss = losses.MatryoshkaLoss(model, loss, [256, 128, 64, 32, 16, 8])

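# Track Spearman correlation between cosine similarities and gold STS-B dev
# scores (normalized from 0-5 to 0-1) during training.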
dev_evaluator = evaluation.EmbeddingSimilarityEvaluator(
    stsb_dev["sentence1"],
    stsb_dev["sentence2"],
    [score / 5 for score in stsb_dev["score"]],
    main_similarity=evaluation.SimilarityFunction.COSINE,
    name="sts-dev",
)

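# Fine-tune: the trainer batches the triplets for the loss above and runs the
# dev evaluator every eval_steps, reporting sts-dev_spearman_cosine.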
trainer = SentenceTransformerTrainer(
    model=model,
    evaluator=dev_evaluator,
    args=training_args,
    train_dataset=all_nli_ds["train"],
    eval_dataset=all_nli_ds["validation"],
    loss=loss,
)
trainer.train()

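# Final evaluation on the held-out STS-B test split.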
test_evaluator = evaluation.EmbeddingSimilarityEvaluator(
    stsb_test["sentence1"],
    stsb_test["sentence2"],
    [score / 5 for score in stsb_test["score"]],
    main_similarity=evaluation.SimilarityFunction.COSINE,
    name="sts-test",
)
results = test_evaluator(model)
print(results)

model.push_to_hub("sentence-transformers-testing/all-nli-bert-tiny-dense", private=True)