import torch
import torch.nn.functional as F
from transformers import BertModel, BertTokenizerFast

def similarity(embeddings_1, embeddings_2):
    """Pairwise cosine similarity between two batches of embeddings.

    Returns a (len(embeddings_1), len(embeddings_2)) matrix of scores.
    """
    normalized_embeddings_1 = F.normalize(embeddings_1, p=2, dim=1)
    normalized_embeddings_2 = F.normalize(embeddings_2, p=2, dim=1)
    return torch.matmul(
        normalized_embeddings_1, normalized_embeddings_2.transpose(0, 1)
    )

class LaBSE:
    def __init__(self, device="cpu"):
        # LaBSE is a BERT-based multilingual sentence-embedding model;
        # pass device="cuda" to run on a GPU.
        self.tokenizer = BertTokenizerFast.from_pretrained("setu4993/LaBSE")
        self.model = BertModel.from_pretrained("setu4993/LaBSE").to(device)
        self.model.eval()
        self.device = device

    @torch.no_grad()
    def __call__(self, sentences):
        """Embed a sentence or a list of sentences into L2-normalized vectors."""
        if not isinstance(sentences, list):
            sentences = [sentences]
        tokens = self.tokenizer(
            sentences, return_tensors="pt", padding=True, truncation=True
        ).to(self.device)
        outputs = self.model(**tokens)
        # LaBSE uses the pooler output (tanh over [CLS]) as the sentence embedding.
        embeddings = outputs.pooler_output
        return F.normalize(embeddings, p=2, dim=1).cpu()

    @property
    def dim(self):
        # BERT-base hidden size, hence 768-dimensional embeddings.
        return 768

if __name__ == "__main__":
    model = LaBSE()
    sents = ["arma virumque cano", "arma gravi numero violentaque bella parabam"]

    output = model(sents)
    print("end", output.shape)